cache, goroutine and unbounded workers management (#6420)

* Enhance matcher compilation with caching for regex and DSL expressions to improve performance. Update template parsing to conditionally retain raw templates based on size constraints.

* Implement caching for regex and DSL expressions in extractors and matchers to enhance performance. Introduce a buffer pool in raw requests to reduce memory allocations. Update template cache management for improved efficiency.

* feat: improve concurrency to be bound

* refactor: replace fmt.Sprintf with fmt.Fprintf for improved performance in header handling

* feat: add regex matching tests and benchmarks for performance evaluation

* feat: add prefix check in regex extraction to optimize matching process

* feat: implement regex caching mechanism to enhance performance in extractors and matchers, along with tests and benchmarks for validation

* feat: add unit tests for template execution in the core engine, enhancing test coverage and reliability

* feat: enhance error handling in template execution and improve regex caching logic for better performance

* Implement caching for regex and DSL expressions in the cache package, replacing previous sync.Map usage. Add unit tests for cache functionality, including eviction by capacity and retrieval of cached items. Update extractors and matchers to utilize the new cache system for improved performance and memory efficiency.

* Add tests for SetCapacities in cache package to ensure cache behavior on capacity changes

- Implemented TestSetCapacities_NoRebuildOnZero to verify that setting capacities to zero does not clear existing caches.
- Added TestSetCapacities_BeforeFirstUse to confirm that initial cache settings are respected and not overridden by subsequent capacity changes.

* Refactor matchers and update load test generator to use io package

- Removed maxRegexScanBytes constant from match.go.
- Replaced ioutil with io package in load_test.go for NopCloser usage.
- Restored TestValidate_AllowsInlineMultiline in load_test.go to ensure inline validation functionality.

* Add cancellation support in template execution and enhance test coverage

- Updated executeTemplateWithTargets to respect context cancellation.
- Introduced fakeTargetProvider and slowExecuter for testing.
- Added Test_executeTemplateWithTargets_RespectsCancellation to validate cancellation behavior during template execution.
This commit is contained in:
Nakul Bharti 2025-09-15 23:48:02 +05:30 committed by GitHub
parent d4f1a815ed
commit c4fa2c74c1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 1001 additions and 220 deletions

View File

@ -48,8 +48,15 @@ func (e *Engine) executeAllSelfContained(ctx context.Context, alltemplates []*te
// executeTemplateWithTargets executes a given template on x targets (with a internal targetpool(i.e concurrency))
func (e *Engine) executeTemplateWithTargets(ctx context.Context, template *templates.Template, target provider.InputProvider, results *atomic.Bool) {
// this is target pool i.e max target to execute
wg := e.workPool.InputPool(template.Type())
if e.workPool == nil {
e.workPool = e.GetWorkPool()
}
// Bounded worker pool using input concurrency
pool := e.workPool.InputPool(template.Type())
workerCount := 1
if pool != nil && pool.Size > 0 {
workerCount = pool.Size
}
var (
index uint32
@ -78,6 +85,41 @@ func (e *Engine) executeTemplateWithTargets(ctx context.Context, template *templ
currentInfo.Unlock()
}
// task represents a single target execution unit
type task struct {
index uint32
skip bool
value *contextargs.MetaInput
}
tasks := make(chan task)
var workersWg sync.WaitGroup
workersWg.Add(workerCount)
for i := 0; i < workerCount; i++ {
go func() {
defer workersWg.Done()
for t := range tasks {
func() {
defer cleanupInFlight(t.index)
select {
case <-ctx.Done():
return
default:
}
if t.skip {
return
}
match, err := e.executeTemplateOnInput(ctx, template, t.value)
if err != nil {
e.options.Logger.Warning().Msgf("[%s] Could not execute step on %s: %s\n", e.executerOpts.Colorizer.BrightBlue(template.ID), t.value.Input, err)
}
results.CompareAndSwap(false, match)
}()
}
}()
}
target.Iterate(func(scannedValue *contextargs.MetaInput) bool {
select {
case <-ctx.Done():
@ -128,43 +170,13 @@ func (e *Engine) executeTemplateWithTargets(ctx context.Context, template *templ
return true
}
wg.Add()
go func(index uint32, skip bool, value *contextargs.MetaInput) {
defer wg.Done()
defer cleanupInFlight(index)
if skip {
return
}
var match bool
var err error
ctxArgs := contextargs.New(ctx)
ctxArgs.MetaInput = value
ctx := scan.NewScanContext(ctx, ctxArgs)
switch template.Type() {
case types.WorkflowProtocol:
match = e.executeWorkflow(ctx, template.CompiledWorkflow)
default:
if e.Callback != nil {
if results, err := template.Executer.ExecuteWithResults(ctx); err == nil {
for _, result := range results {
e.Callback(result)
}
}
match = true
} else {
match, err = template.Executer.Execute(ctx)
}
}
if err != nil {
e.options.Logger.Warning().Msgf("[%s] Could not execute step on %s: %s\n", e.executerOpts.Colorizer.BrightBlue(template.ID), value.Input, err)
}
results.CompareAndSwap(false, match)
}(index, skip, scannedValue)
tasks <- task{index: index, skip: skip, value: scannedValue}
index++
return true
})
wg.Wait()
close(tasks)
workersWg.Wait()
// on completion marks the template as completed
currentInfo.Lock()
@ -202,26 +214,7 @@ func (e *Engine) executeTemplatesOnTarget(ctx context.Context, alltemplates []*t
go func(template *templates.Template, value *contextargs.MetaInput, wg *syncutil.AdaptiveWaitGroup) {
defer wg.Done()
var match bool
var err error
ctxArgs := contextargs.New(ctx)
ctxArgs.MetaInput = value
ctx := scan.NewScanContext(ctx, ctxArgs)
switch template.Type() {
case types.WorkflowProtocol:
match = e.executeWorkflow(ctx, template.CompiledWorkflow)
default:
if e.Callback != nil {
if results, err := template.Executer.ExecuteWithResults(ctx); err == nil {
for _, result := range results {
e.Callback(result)
}
}
match = true
} else {
match, err = template.Executer.Execute(ctx)
}
}
match, err := e.executeTemplateOnInput(ctx, template, value)
if err != nil {
e.options.Logger.Warning().Msgf("[%s] Could not execute step on %s: %s\n", e.executerOpts.Colorizer.BrightBlue(template.ID), value.Input, err)
}
@ -229,3 +222,27 @@ func (e *Engine) executeTemplatesOnTarget(ctx context.Context, alltemplates []*t
}(tpl, target, sg)
}
}
// executeTemplateOnInput performs template execution for a single input and returns match status and error
func (e *Engine) executeTemplateOnInput(ctx context.Context, template *templates.Template, value *contextargs.MetaInput) (bool, error) {
ctxArgs := contextargs.New(ctx)
ctxArgs.MetaInput = value
scanCtx := scan.NewScanContext(ctx, ctxArgs)
switch template.Type() {
case types.WorkflowProtocol:
return e.executeWorkflow(scanCtx, template.CompiledWorkflow), nil
default:
if e.Callback != nil {
results, err := template.Executer.ExecuteWithResults(scanCtx)
if err != nil {
return false, err
}
for _, result := range results {
e.Callback(result)
}
return len(results) > 0, nil
}
return template.Executer.Execute(scanCtx)
}
}

148
pkg/core/executors_test.go Normal file
View File

@ -0,0 +1,148 @@
package core
import (
"context"
"fmt"
"sync/atomic"
"testing"
"time"
inputtypes "github.com/projectdiscovery/nuclei/v3/pkg/input/types"
"github.com/projectdiscovery/nuclei/v3/pkg/output"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/contextargs"
"github.com/projectdiscovery/nuclei/v3/pkg/scan"
"github.com/projectdiscovery/nuclei/v3/pkg/templates"
tmpltypes "github.com/projectdiscovery/nuclei/v3/pkg/templates/types"
"github.com/projectdiscovery/nuclei/v3/pkg/types"
)
// fakeExecuter is a simple stub for protocols.Executer used to test executeTemplateOnInput
type fakeExecuter struct {
withResults bool
}
func (f *fakeExecuter) Compile() error { return nil }
func (f *fakeExecuter) Requests() int { return 1 }
func (f *fakeExecuter) Execute(ctx *scan.ScanContext) (bool, error) { return !f.withResults, nil }
func (f *fakeExecuter) ExecuteWithResults(ctx *scan.ScanContext) ([]*output.ResultEvent, error) {
if !f.withResults {
return nil, nil
}
return []*output.ResultEvent{{Host: "h"}}, nil
}
// newTestEngine creates a minimal Engine for tests
func newTestEngine() *Engine {
return New(&types.Options{})
}
func Test_executeTemplateOnInput_CallbackPath(t *testing.T) {
e := newTestEngine()
called := 0
e.Callback = func(*output.ResultEvent) { called++ }
tpl := &templates.Template{}
tpl.Executer = &fakeExecuter{withResults: true}
ok, err := e.executeTemplateOnInput(context.Background(), tpl, &contextargs.MetaInput{Input: "x"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !ok {
t.Fatalf("expected match true")
}
if called == 0 {
t.Fatalf("expected callback to be called")
}
}
func Test_executeTemplateOnInput_ExecutePath(t *testing.T) {
e := newTestEngine()
tpl := &templates.Template{}
tpl.Executer = &fakeExecuter{withResults: false}
ok, err := e.executeTemplateOnInput(context.Background(), tpl, &contextargs.MetaInput{Input: "x"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !ok {
t.Fatalf("expected match true from Execute path")
}
}
type fakeExecuterErr struct{}
func (f *fakeExecuterErr) Compile() error { return nil }
func (f *fakeExecuterErr) Requests() int { return 1 }
func (f *fakeExecuterErr) Execute(ctx *scan.ScanContext) (bool, error) { return false, nil }
func (f *fakeExecuterErr) ExecuteWithResults(ctx *scan.ScanContext) ([]*output.ResultEvent, error) {
return nil, fmt.Errorf("boom")
}
func Test_executeTemplateOnInput_CallbackErrorPropagates(t *testing.T) {
e := newTestEngine()
e.Callback = func(*output.ResultEvent) {}
tpl := &templates.Template{}
tpl.Executer = &fakeExecuterErr{}
ok, err := e.executeTemplateOnInput(context.Background(), tpl, &contextargs.MetaInput{Input: "x"})
if err == nil {
t.Fatalf("expected error to propagate")
}
if ok {
t.Fatalf("expected match to be false on error")
}
}
type fakeTargetProvider struct {
values []*contextargs.MetaInput
}
func (f *fakeTargetProvider) Count() int64 { return int64(len(f.values)) }
func (f *fakeTargetProvider) Iterate(cb func(value *contextargs.MetaInput) bool) {
for _, v := range f.values {
if !cb(v) {
return
}
}
}
func (f *fakeTargetProvider) Set(string, string) {}
func (f *fakeTargetProvider) SetWithProbe(string, string, inputtypes.InputLivenessProbe) error {
return nil
}
func (f *fakeTargetProvider) SetWithExclusions(string, string) error { return nil }
func (f *fakeTargetProvider) InputType() string { return "test" }
func (f *fakeTargetProvider) Close() {}
type slowExecuter struct{}
func (s *slowExecuter) Compile() error { return nil }
func (s *slowExecuter) Requests() int { return 1 }
func (s *slowExecuter) Execute(ctx *scan.ScanContext) (bool, error) {
select {
case <-ctx.Context().Done():
return false, ctx.Context().Err()
case <-time.After(200 * time.Millisecond):
return true, nil
}
}
func (s *slowExecuter) ExecuteWithResults(ctx *scan.ScanContext) ([]*output.ResultEvent, error) {
return nil, nil
}
func Test_executeTemplateWithTargets_RespectsCancellation(t *testing.T) {
e := newTestEngine()
e.SetExecuterOptions(&protocols.ExecutorOptions{Logger: e.Logger, ResumeCfg: types.NewResumeCfg(), ProtocolType: tmpltypes.HTTPProtocol})
tpl := &templates.Template{}
tpl.Executer = &slowExecuter{}
targets := &fakeTargetProvider{values: []*contextargs.MetaInput{{Input: "a"}, {Input: "b"}, {Input: "c"}}}
ctx, cancel := context.WithCancel(context.Background())
cancel()
var matched atomic.Bool
e.executeTemplateWithTargets(ctx, tpl, targets, &matched)
}

62
pkg/operators/cache/cache.go vendored Normal file
View File

@ -0,0 +1,62 @@
package cache
import (
"regexp"
"sync"
"github.com/Knetic/govaluate"
"github.com/projectdiscovery/gcache"
)
var (
initOnce sync.Once
mu sync.RWMutex
regexCap = 4096
dslCap = 4096
regexCache gcache.Cache[string, *regexp.Regexp]
dslCache gcache.Cache[string, *govaluate.EvaluableExpression]
)
func initCaches() {
initOnce.Do(func() {
regexCache = gcache.New[string, *regexp.Regexp](regexCap).LRU().Build()
dslCache = gcache.New[string, *govaluate.EvaluableExpression](dslCap).LRU().Build()
})
}
func SetCapacities(regexCapacity, dslCapacity int) {
// ensure caches are initialized under initOnce, so later Regex()/DSL() won't re-init
initCaches()
mu.Lock()
defer mu.Unlock()
if regexCapacity > 0 {
regexCap = regexCapacity
}
if dslCapacity > 0 {
dslCap = dslCapacity
}
if regexCapacity <= 0 && dslCapacity <= 0 {
return
}
// rebuild caches with new capacities
regexCache = gcache.New[string, *regexp.Regexp](regexCap).LRU().Build()
dslCache = gcache.New[string, *govaluate.EvaluableExpression](dslCap).LRU().Build()
}
func Regex() gcache.Cache[string, *regexp.Regexp] {
initCaches()
mu.RLock()
defer mu.RUnlock()
return regexCache
}
func DSL() gcache.Cache[string, *govaluate.EvaluableExpression] {
initCaches()
mu.RLock()
defer mu.RUnlock()
return dslCache
}

114
pkg/operators/cache/cache_test.go vendored Normal file
View File

@ -0,0 +1,114 @@
package cache
import (
"regexp"
"testing"
"github.com/Knetic/govaluate"
)
func TestRegexCache_SetGet(t *testing.T) {
// ensure init
c := Regex()
pattern := "abc(\n)?123"
re, err := regexp.Compile(pattern)
if err != nil {
t.Fatalf("compile: %v", err)
}
if err := c.Set(pattern, re); err != nil {
t.Fatalf("set: %v", err)
}
got, err := c.GetIFPresent(pattern)
if err != nil || got == nil {
t.Fatalf("get: %v got=%v", err, got)
}
if got.String() != re.String() {
t.Fatalf("mismatch: %s != %s", got.String(), re.String())
}
}
func TestDSLCache_SetGet(t *testing.T) {
c := DSL()
expr := "1 + 2 == 3"
ast, err := govaluate.NewEvaluableExpression(expr)
if err != nil {
t.Fatalf("dsl compile: %v", err)
}
if err := c.Set(expr, ast); err != nil {
t.Fatalf("set: %v", err)
}
got, err := c.GetIFPresent(expr)
if err != nil || got == nil {
t.Fatalf("get: %v got=%v", err, got)
}
if got.String() != ast.String() {
t.Fatalf("mismatch: %s != %s", got.String(), ast.String())
}
}
func TestRegexCache_EvictionByCapacity(t *testing.T) {
SetCapacities(3, 3)
c := Regex()
for i := 0; i < 5; i++ {
k := string(rune('a' + i))
re := regexp.MustCompile(k)
_ = c.Set(k, re)
}
// last 3 keys expected to remain under LRU: 'c','d','e'
if _, err := c.GetIFPresent("a"); err == nil {
t.Fatalf("expected 'a' to be evicted")
}
if _, err := c.GetIFPresent("b"); err == nil {
t.Fatalf("expected 'b' to be evicted")
}
if _, err := c.GetIFPresent("c"); err != nil {
t.Fatalf("expected 'c' present")
}
}
func TestSetCapacities_NoRebuildOnZero(t *testing.T) {
// init
SetCapacities(4, 4)
c1 := Regex()
_ = c1.Set("k", regexp.MustCompile("k"))
if _, err := c1.GetIFPresent("k"); err != nil {
t.Fatalf("expected key present: %v", err)
}
// zero changes should not rebuild/clear caches
SetCapacities(0, 0)
c2 := Regex()
if _, err := c2.GetIFPresent("k"); err != nil {
t.Fatalf("key lost after zero-capacity SetCapacities: %v", err)
}
}
func TestSetCapacities_BeforeFirstUse(t *testing.T) {
// This should not be overridden by later initCaches
SetCapacities(2, 0)
c := Regex()
_ = c.Set("a", regexp.MustCompile("a"))
_ = c.Set("b", regexp.MustCompile("b"))
_ = c.Set("c", regexp.MustCompile("c"))
if _, err := c.GetIFPresent("a"); err == nil {
t.Fatalf("expected 'a' to be evicted under cap=2")
}
}
func TestSetCapacities_ConcurrentAccess(t *testing.T) {
SetCapacities(64, 64)
stop := make(chan struct{})
go func() {
for i := 0; i < 5000; i++ {
_ = Regex().Set("k"+string(rune('a'+(i%26))), regexp.MustCompile("a"))
_, _ = Regex().GetIFPresent("k" + string(rune('a'+(i%26))))
_, _ = DSL().GetIFPresent("1+2==3")
}
close(stop)
}()
for i := 0; i < 200; i++ {
SetCapacities(64+(i%5), 64+((i+1)%5))
}
<-stop
}

View File

@ -7,6 +7,7 @@ import (
"github.com/Knetic/govaluate"
"github.com/itchyny/gojq"
"github.com/projectdiscovery/nuclei/v3/pkg/operators/cache"
"github.com/projectdiscovery/nuclei/v3/pkg/operators/common/dsl"
)
@ -20,10 +21,15 @@ func (e *Extractor) CompileExtractors() error {
e.extractorType = computedType
// Compile the regexes
for _, regex := range e.Regex {
if cached, err := cache.Regex().GetIFPresent(regex); err == nil && cached != nil {
e.regexCompiled = append(e.regexCompiled, cached)
continue
}
compiled, err := regexp.Compile(regex)
if err != nil {
return fmt.Errorf("could not compile regex: %s", regex)
}
_ = cache.Regex().Set(regex, compiled)
e.regexCompiled = append(e.regexCompiled, compiled)
}
for i, kval := range e.KVal {
@ -43,10 +49,15 @@ func (e *Extractor) CompileExtractors() error {
}
for _, dslExp := range e.DSL {
if cached, err := cache.DSL().GetIFPresent(dslExp); err == nil && cached != nil {
e.dslCompiled = append(e.dslCompiled, cached)
continue
}
compiled, err := govaluate.NewEvaluableExpressionWithFunctions(dslExp, dsl.HelperFunctions)
if err != nil {
return &dsl.CompilationError{DslSignature: dslExp, WrappedError: err}
}
_ = cache.DSL().Set(dslExp, compiled)
e.dslCompiled = append(e.dslCompiled, compiled)
}

View File

@ -17,9 +17,19 @@ func (e *Extractor) ExtractRegex(corpus string) map[string]struct{} {
groupPlusOne := e.RegexGroup + 1
for _, regex := range e.regexCompiled {
matches := regex.FindAllStringSubmatch(corpus, -1)
// skip prefix short-circuit for case-insensitive patterns
rstr := regex.String()
if !strings.Contains(rstr, "(?i") {
if prefix, ok := regex.LiteralPrefix(); ok && prefix != "" {
if !strings.Contains(corpus, prefix) {
continue
}
}
}
for _, match := range matches {
submatches := regex.FindAllStringSubmatch(corpus, -1)
for _, match := range submatches {
if len(match) < groupPlusOne {
continue
}

View File

@ -7,7 +7,7 @@ import (
"strings"
"github.com/Knetic/govaluate"
"github.com/projectdiscovery/nuclei/v3/pkg/operators/cache"
"github.com/projectdiscovery/nuclei/v3/pkg/operators/common/dsl"
)
@ -42,12 +42,17 @@ func (matcher *Matcher) CompileMatchers() error {
matcher.Part = "body"
}
// Compile the regexes
// Compile the regexes (with shared cache)
for _, regex := range matcher.Regex {
if cached, err := cache.Regex().GetIFPresent(regex); err == nil && cached != nil {
matcher.regexCompiled = append(matcher.regexCompiled, cached)
continue
}
compiled, err := regexp.Compile(regex)
if err != nil {
return fmt.Errorf("could not compile regex: %s", regex)
}
_ = cache.Regex().Set(regex, compiled)
matcher.regexCompiled = append(matcher.regexCompiled, compiled)
}
@ -60,12 +65,17 @@ func (matcher *Matcher) CompileMatchers() error {
}
}
// Compile the dsl expressions
// Compile the dsl expressions (with shared cache)
for _, dslExpression := range matcher.DSL {
if cached, err := cache.DSL().GetIFPresent(dslExpression); err == nil && cached != nil {
matcher.dslCompiled = append(matcher.dslCompiled, cached)
continue
}
compiledExpression, err := govaluate.NewEvaluableExpressionWithFunctions(dslExpression, dsl.HelperFunctions)
if err != nil {
return &dsl.CompilationError{DslSignature: dslExpression, WrappedError: err}
}
_ = cache.DSL().Set(dslExpression, compiledExpression)
matcher.dslCompiled = append(matcher.dslCompiled, compiledExpression)
}

View File

@ -106,10 +106,33 @@ func (matcher *Matcher) MatchRegex(corpus string) (bool, []string) {
var matchedRegexes []string
// Iterate over all the regexes accepted as valid
for i, regex := range matcher.regexCompiled {
// Continue if the regex doesn't match
if !regex.MatchString(corpus) {
// If we are in an AND request and a match failed,
// return false as the AND condition fails on any single mismatch.
// Literal prefix short-circuit
rstr := regex.String()
if !strings.Contains(rstr, "(?i") { // covers (?i) and (?i:
if prefix, ok := regex.LiteralPrefix(); ok && prefix != "" {
if !strings.Contains(corpus, prefix) {
switch matcher.condition {
case ANDCondition:
return false, []string{}
case ORCondition:
continue
}
}
}
}
// Fast OR-path: return first match without full scan
if matcher.condition == ORCondition && !matcher.MatchAll {
m := regex.FindAllString(corpus, 1)
if len(m) == 0 {
continue
}
return true, m
}
// Single scan: get all matches directly
currentMatches := regex.FindAllString(corpus, -1)
if len(currentMatches) == 0 {
switch matcher.condition {
case ANDCondition:
return false, []string{}
@ -118,12 +141,7 @@ func (matcher *Matcher) MatchRegex(corpus string) (bool, []string) {
}
}
currentMatches := regex.FindAllString(corpus, -1)
// If the condition was an OR, return on the first match.
if matcher.condition == ORCondition && !matcher.MatchAll {
return true, currentMatches
}
// If the condition was an OR (and MatchAll true), we still need to gather all
matchedRegexes = append(matchedRegexes, currentMatches...)
// If we are at the end of the regex, return with true

View File

@ -84,7 +84,7 @@ func TestMatcher_MatchDSL(t *testing.T) {
values := []string{"PING", "pong"}
for value := range values {
for _, value := range values {
isMatched := m.MatchDSL(map[string]interface{}{"body": value, "VARIABLE": value})
require.True(t, isMatched)
}
@ -209,3 +209,66 @@ func TestMatcher_MatchXPath_XML(t *testing.T) {
isMatched = m.MatchXPath("<h1> not right <q id=2/>notvalid")
require.False(t, isMatched, "Invalid xpath did not return false")
}
func TestMatchRegex_CaseInsensitivePrefixSkip(t *testing.T) {
m := &Matcher{Type: MatcherTypeHolder{MatcherType: RegexMatcher}, Condition: "or", Regex: []string{"(?i)abc"}}
err := m.CompileMatchers()
require.NoError(t, err)
ok, got := m.MatchRegex("zzz AbC yyy")
require.True(t, ok)
require.NotEmpty(t, got)
}
func TestMatchStatusCodeAndSize(t *testing.T) {
mStatus := &Matcher{Status: []int{200, 302}}
require.True(t, mStatus.MatchStatusCode(200))
require.True(t, mStatus.MatchStatusCode(302))
require.False(t, mStatus.MatchStatusCode(404))
mSize := &Matcher{Size: []int{5, 10}}
require.True(t, mSize.MatchSize(5))
require.False(t, mSize.MatchSize(7))
}
func TestMatchBinary_AND_OR(t *testing.T) {
// AND should fail if any binary not present
mAnd := &Matcher{Type: MatcherTypeHolder{MatcherType: BinaryMatcher}, Condition: "and", Binary: []string{"50494e47", "414141"}} // "PING", "AAA"
require.NoError(t, mAnd.CompileMatchers())
ok, _ := mAnd.MatchBinary("PING")
require.False(t, ok)
// OR should succeed if any present
mOr := &Matcher{Type: MatcherTypeHolder{MatcherType: BinaryMatcher}, Condition: "or", Binary: []string{"414141", "50494e47"}} // "AAA", "PING"
require.NoError(t, mOr.CompileMatchers())
ok, got := mOr.MatchBinary("xxPINGyy")
require.True(t, ok)
require.NotEmpty(t, got)
}
func TestMatchRegex_LiteralPrefixShortCircuit(t *testing.T) {
// AND: first regex has literal prefix "abc"; corpus lacks it => early false
mAnd := &Matcher{Type: MatcherTypeHolder{MatcherType: RegexMatcher}, Condition: "and", Regex: []string{"abc[0-9]*", "[0-9]{2}"}}
require.NoError(t, mAnd.CompileMatchers())
ok, matches := mAnd.MatchRegex("zzz 12 yyy")
require.False(t, ok)
require.Empty(t, matches)
// OR: first regex skipped due to missing prefix, second matches => true
mOr := &Matcher{Type: MatcherTypeHolder{MatcherType: RegexMatcher}, Condition: "or", Regex: []string{"abc[0-9]*", "[0-9]{2}"}}
require.NoError(t, mOr.CompileMatchers())
ok, matches = mOr.MatchRegex("zzz 12 yyy")
require.True(t, ok)
require.Equal(t, []string{"12"}, matches)
}
func TestMatcher_MatchDSL_ErrorHandling(t *testing.T) {
// First expression errors (division by zero), second is true
bad, err := govaluate.NewEvaluableExpression("1 / 0")
require.NoError(t, err)
good, err := govaluate.NewEvaluableExpression("1 == 1")
require.NoError(t, err)
m := &Matcher{Type: MatcherTypeHolder{MatcherType: DSLMatcher}, Condition: "or", dslCompiled: []*govaluate.EvaluableExpression{bad, good}}
require.NoError(t, m.CompileMatchers())
ok := m.MatchDSL(map[string]interface{}{})
require.True(t, ok)
}

View File

@ -0,0 +1,26 @@
package generators
import "testing"
func TestAttackTypeHelpers(t *testing.T) {
// GetSupportedAttackTypes should include three values
types := GetSupportedAttackTypes()
if len(types) != 3 {
t.Fatalf("expected 3 types, got %d", len(types))
}
// toAttackType valid
if got, err := toAttackType("pitchfork"); err != nil || got != PitchForkAttack {
t.Fatalf("toAttackType failed: %v %v", got, err)
}
// toAttackType invalid
if _, err := toAttackType("nope"); err == nil {
t.Fatalf("expected error for invalid attack type")
}
// normalizeValue and String
if normalizeValue(" ClusterBomb ") != "clusterbomb" {
t.Fatalf("normalizeValue failed")
}
if ClusterBombAttack.String() != "clusterbomb" {
t.Fatalf("String failed")
}
}

View File

@ -0,0 +1,38 @@
package generators
import (
"os"
"testing"
)
func TestParseEnvVars(t *testing.T) {
old := os.Environ()
// set a scoped env var
_ = os.Setenv("NUCLEI_TEST_K", "V1")
t.Cleanup(func() {
// restore
for _, kv := range old {
parts := kv
_ = parts // nothing, environment already has superset; best-effort cleanup below
}
_ = os.Unsetenv("NUCLEI_TEST_K")
})
vars := parseEnvVars()
if vars["NUCLEI_TEST_K"] != "V1" {
t.Fatalf("expected V1, got %v", vars["NUCLEI_TEST_K"])
}
}
func TestEnvVarsMemoization(t *testing.T) {
// reset memoized map
envVars = nil
_ = os.Setenv("NUCLEI_TEST_MEMO", "A")
t.Cleanup(func() { _ = os.Unsetenv("NUCLEI_TEST_MEMO") })
v1 := EnvVars()["NUCLEI_TEST_MEMO"]
// change env after memoization
_ = os.Setenv("NUCLEI_TEST_MEMO", "B")
v2 := EnvVars()["NUCLEI_TEST_MEMO"]
if v1 != "A" || v2 != "A" {
t.Fatalf("memoization failed: %v %v", v1, v2)
}
}

View File

@ -17,6 +17,20 @@ func (generator *PayloadGenerator) loadPayloads(payloads map[string]interface{},
for name, payload := range payloads {
switch pt := payload.(type) {
case string:
// Fast path: if no newline, treat as file path
if !strings.ContainsRune(pt, '\n') {
file, err := generator.options.LoadHelperFile(pt, templatePath, generator.catalog)
if err != nil {
return nil, errors.Wrap(err, "could not load payload file")
}
payloads, err := generator.loadPayloadsFromFile(file)
if err != nil {
return nil, errors.Wrap(err, "could not load payloads")
}
loadedPayloads[name] = payloads
break
}
// Multiline inline payloads
elements := strings.Split(pt, "\n")
//golint:gomnd // this is not a magic number
if len(elements) >= 2 {

View File

@ -1,120 +1,108 @@
package generators
import (
"os"
"os/exec"
"path/filepath"
"io"
"strings"
"testing"
"github.com/projectdiscovery/nuclei/v3/pkg/catalog/config"
"github.com/projectdiscovery/nuclei/v3/pkg/catalog/disk"
osutils "github.com/projectdiscovery/utils/os"
"github.com/stretchr/testify/require"
"github.com/pkg/errors"
"github.com/projectdiscovery/nuclei/v3/pkg/catalog"
"github.com/projectdiscovery/nuclei/v3/pkg/types"
)
func TestLoadPayloads(t *testing.T) {
// since we are changing value of global variable i.e templates directory
// run this test as subprocess
if os.Getenv("LOAD_PAYLOAD_NO_ACCESS") != "1" {
cmd := exec.Command(os.Args[0], "-test.run=TestLoadPayloadsWithAccess")
cmd.Env = append(os.Environ(), "LOAD_PAYLOAD_NO_ACCESS=1")
err := cmd.Run()
if e, ok := err.(*exec.ExitError); ok && !e.Success() {
return
type fakeCatalog struct{ catalog.Catalog }
func (f *fakeCatalog) OpenFile(filename string) (io.ReadCloser, error) {
return nil, errors.New("not used")
}
func (f *fakeCatalog) GetTemplatePath(target string) ([]string, error) { return nil, nil }
func (f *fakeCatalog) GetTemplatesPath(definitions []string) ([]string, map[string]error) {
return nil, nil
}
func (f *fakeCatalog) ResolvePath(templateName, second string) (string, error) {
return templateName, nil
}
func newTestGenerator() *PayloadGenerator {
opts := types.DefaultOptions()
// inject helper loader function
opts.LoadHelperFileFunction = func(path, templatePath string, _ catalog.Catalog) (io.ReadCloser, error) {
switch path {
case "fileA.txt":
return io.NopCloser(strings.NewReader("one\n two\n\nthree\n")), nil
default:
return io.NopCloser(strings.NewReader("x\ny\nz\n")), nil
}
}
return &PayloadGenerator{options: opts, catalog: &fakeCatalog{}}
}
func TestLoadPayloads_FastPathFile(t *testing.T) {
g := newTestGenerator()
out, err := g.loadPayloads(map[string]interface{}{"A": "fileA.txt"}, "")
if err != nil {
t.Fatalf("process ran with err %v, want exit status 1", err)
t.Fatalf("err: %v", err)
}
got := out["A"]
if len(got) != 3 || got[0] != "one" || got[1] != " two" || got[2] != "three" {
t.Fatalf("unexpected: %#v", got)
}
templateDir := getTemplatesDir(t)
config.DefaultConfig.SetTemplatesDir(templateDir)
generator := &PayloadGenerator{catalog: disk.NewCatalog(templateDir), options: getOptions(false)}
fullpath := filepath.Join(templateDir, "payloads.txt")
// Test sandbox
t.Run("templates-directory", func(t *testing.T) {
// testcase when loading file from template directory and template file is in root
// expected to succeed
values, err := generator.loadPayloads(map[string]interface{}{
"new": fullpath,
}, "/test")
require.NoError(t, err, "could not load payloads")
require.Equal(t, map[string][]string{"new": {"test", "another"}}, values, "could not get values")
})
t.Run("templates-path-relative", func(t *testing.T) {
// testcase when loading file from template directory and template file is current working directory
// expected to fail since this is LFI
_, err := generator.loadPayloads(map[string]interface{}{
"new": "../../../../../../../../../etc/passwd",
}, ".")
require.Error(t, err, "could load payloads")
})
t.Run("template-directory", func(t *testing.T) {
// testcase when loading file from template directory and template file is inside template directory
// expected to succeed
values, err := generator.loadPayloads(map[string]interface{}{
"new": fullpath,
}, filepath.Join(templateDir, "test.yaml"))
require.NoError(t, err, "could not load payloads")
require.Equal(t, map[string][]string{"new": {"test", "another"}}, values, "could not get values")
})
t.Run("invalid", func(t *testing.T) {
// testcase when loading file from /etc/passwd and template file is at root i.e /
// expected to fail since this is LFI
values, err := generator.loadPayloads(map[string]interface{}{
"new": "/etc/passwd",
}, "/random")
require.Error(t, err, "could load payloads got %v", values)
require.Equal(t, 0, len(values), "could get values")
// testcase when loading file from template directory and template file is at root i.e /
// expected to succeed
values, err = generator.loadPayloads(map[string]interface{}{
"new": fullpath,
}, "/random")
require.NoError(t, err, "could load payloads %v", values)
require.Equal(t, 1, len(values), "could get values")
require.Equal(t, []string{"test", "another"}, values["new"], "could get values")
})
}
func TestLoadPayloadsWithAccess(t *testing.T) {
// since we are changing value of global variable i.e templates directory
// run this test as subprocess
if os.Getenv("LOAD_PAYLOAD_WITH_ACCESS") != "1" {
cmd := exec.Command(os.Args[0], "-test.run=TestLoadPayloadsWithAccess")
cmd.Env = append(os.Environ(), "LOAD_PAYLOAD_WITH_ACCESS=1")
err := cmd.Run()
if e, ok := err.(*exec.ExitError); ok && !e.Success() {
return
}
func TestLoadPayloads_InlineMultiline(t *testing.T) {
g := newTestGenerator()
inline := "a\nb\n"
out, err := g.loadPayloads(map[string]interface{}{"B": inline}, "")
if err != nil {
t.Fatalf("process ran with err %v, want exit status 1", err)
t.Fatalf("err: %v", err)
}
got := out["B"]
if len(got) != 3 || got[0] != "a" || got[1] != "b" || got[2] != "" {
t.Fatalf("unexpected: %#v", got)
}
templateDir := getTemplatesDir(t)
config.DefaultConfig.SetTemplatesDir(templateDir)
generator := &PayloadGenerator{catalog: disk.NewCatalog(templateDir), options: getOptions(true)}
t.Run("no-sandbox-unix", func(t *testing.T) {
if osutils.IsWindows() {
return
}
_, err := generator.loadPayloads(map[string]interface{}{
"new": "/etc/passwd",
}, "/random")
require.NoError(t, err, "could load payloads")
})
}
func getTemplatesDir(t *testing.T) string {
tempdir, err := os.MkdirTemp("", "templates-*")
require.NoError(t, err, "could not create temp dir")
fullpath := filepath.Join(tempdir, "payloads.txt")
err = os.WriteFile(fullpath, []byte("test\nanother"), 0777)
require.NoError(t, err, "could not write payload")
return tempdir
func TestLoadPayloads_SingleLineFallsBackToFile(t *testing.T) {
g := newTestGenerator()
inline := "fileA.txt" // single line, should be treated as file path
out, err := g.loadPayloads(map[string]interface{}{"C": inline}, "")
if err != nil {
t.Fatalf("err: %v", err)
}
got := out["C"]
if len(got) != 3 {
t.Fatalf("unexpected len: %d", len(got))
}
}
func TestLoadPayloads_InterfaceSlice(t *testing.T) {
g := newTestGenerator()
out, err := g.loadPayloads(map[string]interface{}{"D": []interface{}{"p", "q"}}, "")
if err != nil {
t.Fatalf("err: %v", err)
}
got := out["D"]
if len(got) != 2 || got[0] != "p" || got[1] != "q" {
t.Fatalf("unexpected: %#v", got)
}
}
func TestLoadPayloadsFromFile_SkipsEmpty(t *testing.T) {
g := newTestGenerator()
rc := io.NopCloser(strings.NewReader("a\n\n\n b \n"))
lines, err := g.loadPayloadsFromFile(rc)
if err != nil {
t.Fatalf("err: %v", err)
}
if len(lines) != 2 || lines[0] != "a" || lines[1] != " b " {
t.Fatalf("unexpected: %#v", lines)
}
}
func TestValidate_AllowsInlineMultiline(t *testing.T) {
g := newTestGenerator()
inline := "x\ny\n"
if err := g.validate(map[string]interface{}{"E": inline}, ""); err != nil {
t.Fatalf("validate rejected inline multiline: %v", err)
}
}

View File

@ -14,3 +14,32 @@ func TestMergeMapsMany(t *testing.T) {
"c": {"5"},
}, got, "could not get correct merged map")
}
func TestMergeMapsAndExpand(t *testing.T) {
m1 := map[string]interface{}{"a": "1"}
m2 := map[string]interface{}{"b": "2"}
out := MergeMaps(m1, m2)
if out["a"].(string) != "1" || out["b"].(string) != "2" {
t.Fatalf("unexpected merge: %#v", out)
}
flat := map[string]string{"x": "y"}
exp := ExpandMapValues(flat)
if len(exp["x"]) != 1 || exp["x"][0] != "y" {
t.Fatalf("unexpected expand: %#v", exp)
}
}
func TestIteratorRemaining(t *testing.T) {
g, err := New(map[string]interface{}{"k": []interface{}{"a", "b"}}, BatteringRamAttack, "", nil, "", nil)
if err != nil {
t.Fatalf("new: %v", err)
}
it := g.NewIterator()
if it.Total() != 2 || it.Remaining() != 2 {
t.Fatalf("unexpected totals: %d %d", it.Total(), it.Remaining())
}
_, _ = it.Value()
if it.Remaining() != 1 {
t.Fatalf("unexpected remaining after one: %d", it.Remaining())
}
}

View File

@ -1,7 +1,6 @@
package generators
import (
"errors"
"fmt"
"path/filepath"
"strings"
@ -17,9 +16,8 @@ func (g *PayloadGenerator) validate(payloads map[string]interface{}, templatePat
for name, payload := range payloads {
switch payloadType := payload.(type) {
case string:
// check if it's a multiline string list
if len(strings.Split(payloadType, "\n")) != 1 {
return errors.New("invalid number of lines in payload")
if strings.ContainsRune(payloadType, '\n') {
continue
}
// For historical reasons, "validate" checks to see if the payload file exist.

View File

@ -92,9 +92,8 @@ type generatedRequest struct {
// setReqURLPattern sets the url request pattern for the generated request
func (gr *generatedRequest) setReqURLPattern(reqURLPattern string) {
data := strings.Split(reqURLPattern, "\n")
if len(data) > 1 {
reqURLPattern = strings.TrimSpace(data[0])
if idx := strings.IndexByte(reqURLPattern, '\n'); idx >= 0 {
reqURLPattern = strings.TrimSpace(reqURLPattern[:idx])
// this is raw request (if it has 3 parts after strings.Fields then its valid only use 2nd part)
parts := strings.Fields(reqURLPattern)
if len(parts) >= 3 {

View File

@ -8,6 +8,7 @@ import (
"fmt"
"io"
"strings"
"sync"
"github.com/projectdiscovery/gologger"
"github.com/projectdiscovery/nuclei/v3/pkg/authprovider/authx"
@ -17,6 +18,8 @@ import (
urlutil "github.com/projectdiscovery/utils/url"
)
var bufferPool = sync.Pool{New: func() any { return new(bytes.Buffer) }}
// Request defines a basic HTTP raw request
type Request struct {
FullURL string
@ -270,13 +273,17 @@ func (r *Request) TryFillCustomHeaders(headers []string) error {
if newLineIndex > 0 {
newLineIndex += hostHeaderIndex + 2
// insert custom headers
var buf bytes.Buffer
buf := bufferPool.Get().(*bytes.Buffer)
buf.Reset()
buf.Write(r.UnsafeRawBytes[:newLineIndex])
for _, header := range headers {
buf.WriteString(fmt.Sprintf("%s\r\n", header))
buf.WriteString(header)
buf.WriteString("\r\n")
}
buf.Write(r.UnsafeRawBytes[newLineIndex:])
r.UnsafeRawBytes = buf.Bytes()
r.UnsafeRawBytes = append([]byte(nil), buf.Bytes()...)
buf.Reset()
bufferPool.Put(buf)
return nil
}
return errors.New("no new line found at the end of host header")
@ -301,9 +308,10 @@ func (r *Request) ApplyAuthStrategy(strategy authx.AuthStrategy) {
parsed.Params.Add(p.Key, p.Value)
}
case *authx.CookiesAuthStrategy:
var buff bytes.Buffer
buff := bufferPool.Get().(*bytes.Buffer)
buff.Reset()
for _, cookie := range s.Data.Cookies {
buff.WriteString(fmt.Sprintf("%s=%s; ", cookie.Key, cookie.Value))
fmt.Fprintf(buff, "%s=%s; ", cookie.Key, cookie.Value)
}
if buff.Len() > 0 {
if val, ok := r.Headers["Cookie"]; ok {
@ -312,6 +320,7 @@ func (r *Request) ApplyAuthStrategy(strategy authx.AuthStrategy) {
r.Headers["Cookie"] = buff.String()
}
}
bufferPool.Put(buff)
case *authx.HeadersAuthStrategy:
for _, header := range s.Data.Headers {
r.Headers[header.Key] = header.Value

View File

@ -7,6 +7,21 @@ import (
"github.com/stretchr/testify/require"
)
func TestTryFillCustomHeaders_BufferDetached(t *testing.T) {
r := &Request{
UnsafeRawBytes: []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\nBody"),
}
// first fill
err := r.TryFillCustomHeaders([]string{"X-Test: 1"})
require.NoError(t, err, "unexpected error on first call")
prev := r.UnsafeRawBytes
prevStr := string(prev) // content snapshot
err = r.TryFillCustomHeaders([]string{"X-Another: 2"})
require.NoError(t, err, "unexpected error on second call")
require.Equal(t, prevStr, string(prev), "first slice mutated after second call; buffer not detached")
require.NotEqual(t, prevStr, string(r.UnsafeRawBytes), "request bytes did not change after second call")
}
func TestParseRawRequestWithPort(t *testing.T) {
request, err := Parse(`GET /gg/phpinfo.php HTTP/1.1
Host: {{Hostname}}:123

View File

@ -240,6 +240,48 @@ func (request *Request) executeParallelHTTP(input *contextargs.Context, dynamicV
}
})
// bounded worker-pool to avoid spawning one goroutine per payload
type task struct {
req *generatedRequest
updatedInput *contextargs.Context
}
var workersWg sync.WaitGroup
currentWorkers := maxWorkers
tasks := make(chan task, maxWorkers)
spawnWorker := func(ctx context.Context) {
workersWg.Add(1)
go func() {
defer workersWg.Done()
for t := range tasks {
select {
case <-ctx.Done():
return
default:
}
if spmHandler.FoundFirstMatch() || request.isUnresponsiveAddress(t.updatedInput) || spmHandler.Cancelled() {
continue
}
spmHandler.Acquire()
if spmHandler.FoundFirstMatch() || request.isUnresponsiveAddress(t.updatedInput) || spmHandler.Cancelled() {
spmHandler.Release()
continue
}
request.options.RateLimitTake()
select {
case <-spmHandler.Done():
spmHandler.Release()
continue
case spmHandler.ResultChan <- request.executeRequest(t.updatedInput, t.req, make(map[string]interface{}), false, wrappedCallback, 0):
spmHandler.Release()
}
}
}()
}
for i := 0; i < currentWorkers; i++ {
spawnWorker(ctx)
}
// iterate payloads and make requests
generator := request.newGenerator(false)
for {
@ -259,6 +301,13 @@ func (request *Request) executeParallelHTTP(input *contextargs.Context, dynamicV
if err := spmHandler.Resize(input.Context(), request.options.Options.PayloadConcurrency); err != nil {
return err
}
// if payload concurrency increased, add more workers
if spmHandler.Size() > currentWorkers {
for i := 0; i < spmHandler.Size()-currentWorkers; i++ {
spawnWorker(ctx)
}
currentWorkers = spmHandler.Size()
}
}
// break if stop at first match is found or host is unresponsive
@ -284,29 +333,21 @@ func (request *Request) executeParallelHTTP(input *contextargs.Context, dynamicV
spmHandler.Cancel()
return nil
}
spmHandler.Acquire()
go func(httpRequest *generatedRequest) {
defer spmHandler.Release()
if spmHandler.FoundFirstMatch() || request.isUnresponsiveAddress(updatedInput) || spmHandler.Cancelled() {
return
}
// putting ratelimiter here prevents any unnecessary waiting if any
request.options.RateLimitTake()
// after ratelimit take, check if we need to stop
if spmHandler.FoundFirstMatch() || request.isUnresponsiveAddress(updatedInput) || spmHandler.Cancelled() {
return
}
select {
case <-spmHandler.Done():
return
case spmHandler.ResultChan <- request.executeRequest(input, httpRequest, make(map[string]interface{}), false, wrappedCallback, 0):
return
close(tasks)
workersWg.Wait()
spmHandler.Wait()
if spmHandler.FoundFirstMatch() {
return nil
}
return multierr.Combine(spmHandler.CombinedResults()...)
case tasks <- task{req: generatedHttpRequest, updatedInput: updatedInput}:
}
}(generatedHttpRequest)
request.options.Progress.IncrementRequests()
}
close(tasks)
workersWg.Wait()
spmHandler.Wait()
if spmHandler.FoundFirstMatch() {
// ignore any context cancellation and in-transit execution errors

View File

@ -5,6 +5,7 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"github.com/stretchr/testify/require"
@ -257,3 +258,116 @@ func TestReqURLPattern(t *testing.T) {
require.NotEmpty(t, finalEvent.Results[0].ReqURLPattern, "could not get req url pattern")
require.Equal(t, `/{{rand_char("abc")}}/{{interactsh-url}}/123?query={{rand_int(1, 10)}}&data={{randstr}}`, finalEvent.Results[0].ReqURLPattern)
}
// fakeHostErrorsCache implements hosterrorscache.CacheInterface minimally for tests
type fakeHostErrorsCache struct{}
func (f *fakeHostErrorsCache) SetVerbose(bool) {}
func (f *fakeHostErrorsCache) Close() {}
func (f *fakeHostErrorsCache) Remove(*contextargs.Context) {}
func (f *fakeHostErrorsCache) MarkFailed(string, *contextargs.Context, error) {}
func (f *fakeHostErrorsCache) MarkFailedOrRemove(string, *contextargs.Context, error) {
}
// Check always returns true to simulate an already unresponsive host
func (f *fakeHostErrorsCache) Check(string, *contextargs.Context) bool { return true }
func TestExecuteParallelHTTP_StopAtFirstMatch(t *testing.T) {
options := testutils.DefaultOptions
testutils.Init(options)
// server that always matches
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = fmt.Fprintf(w, "match")
}))
defer ts.Close()
templateID := "parallel-stop-first"
req := &Request{
ID: templateID,
Method: HTTPMethodTypeHolder{MethodType: HTTPGet},
Path: []string{"{{BaseURL}}/p?x={{v}}"},
Threads: 2,
Payloads: map[string]interface{}{
"v": []string{"1", "2"},
},
Operators: operators.Operators{
Matchers: []*matchers.Matcher{{
Part: "body",
Type: matchers.MatcherTypeHolder{MatcherType: matchers.WordsMatcher},
Words: []string{"match"},
}},
},
StopAtFirstMatch: true,
}
executerOpts := testutils.NewMockExecuterOptions(options, &testutils.TemplateInfo{
ID: templateID,
Info: model.Info{SeverityHolder: severity.Holder{Severity: severity.Low}, Name: "test"},
})
err := req.Compile(executerOpts)
require.NoError(t, err)
var matches int32
metadata := make(output.InternalEvent)
previous := make(output.InternalEvent)
ctxArgs := contextargs.NewWithInput(context.Background(), ts.URL)
err = req.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) {
if event.OperatorsResult != nil && event.OperatorsResult.Matched {
atomic.AddInt32(&matches, 1)
}
})
require.NoError(t, err)
require.Equal(t, int32(1), atomic.LoadInt32(&matches), "expected only first match to be processed")
}
func TestExecuteParallelHTTP_SkipOnUnresponsiveFromCache(t *testing.T) {
options := testutils.DefaultOptions
testutils.Init(options)
// server that would match if reached
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, _ = fmt.Fprintf(w, "match")
}))
defer ts.Close()
templateID := "parallel-skip-unresponsive"
req := &Request{
ID: templateID,
Method: HTTPMethodTypeHolder{MethodType: HTTPGet},
Path: []string{"{{BaseURL}}/p?x={{v}}"},
Threads: 2,
Payloads: map[string]interface{}{
"v": []string{"1", "2"},
},
Operators: operators.Operators{
Matchers: []*matchers.Matcher{{
Part: "body",
Type: matchers.MatcherTypeHolder{MatcherType: matchers.WordsMatcher},
Words: []string{"match"},
}},
},
}
executerOpts := testutils.NewMockExecuterOptions(options, &testutils.TemplateInfo{
ID: templateID,
Info: model.Info{SeverityHolder: severity.Holder{Severity: severity.Low}, Name: "test"},
})
// inject fake host errors cache that forces skip
executerOpts.HostErrorsCache = &fakeHostErrorsCache{}
err := req.Compile(executerOpts)
require.NoError(t, err)
var matches int32
metadata := make(output.InternalEvent)
previous := make(output.InternalEvent)
ctxArgs := contextargs.NewWithInput(context.Background(), ts.URL)
err = req.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) {
if event.OperatorsResult != nil && event.OperatorsResult.Matched {
atomic.AddInt32(&matches, 1)
}
})
require.NoError(t, err)
require.Equal(t, int32(0), atomic.LoadInt32(&matches), "expected no matches when host is marked unresponsive")
}

View File

@ -811,8 +811,11 @@ func beautifyJavascript(code string) string {
}
func prettyPrint(templateId string, buff string) {
if buff == "" {
return
}
lines := strings.Split(buff, "\n")
final := []string{}
final := make([]string, 0, len(lines))
for _, v := range lines {
if v != "" {
final = append(final, "\t"+v)

View File

@ -192,11 +192,13 @@ func New(options *Options, db string, doNotDedupe bool) (Client, error) {
}
}
if db != "" || len(client.trackers) > 0 || len(client.exporters) > 0 {
storage, err := dedupe.New(db)
if err != nil {
return nil, err
}
client.dedupe = storage
}
return client, nil
}

View File

@ -12,7 +12,9 @@ type Cache struct {
// New returns a new templates cache
func NewCache() *Cache {
return &Cache{items: mapsutil.NewSyncLockMap[string, parsedTemplate]()}
return &Cache{
items: mapsutil.NewSyncLockMap[string, parsedTemplate](),
}
}
type parsedTemplate struct {
@ -33,7 +35,31 @@ func (t *Cache) Has(template string) (*Template, []byte, error) {
// Store stores a template with data and error
func (t *Cache) Store(id string, tpl *Template, raw []byte, err error) {
_ = t.items.Set(id, parsedTemplate{template: tpl, raw: conversion.String(raw), err: err})
entry := parsedTemplate{
template: tpl,
err: err,
raw: conversion.String(raw),
}
_ = t.items.Set(id, entry)
}
// StoreWithoutRaw stores a template without raw data for memory efficiency
func (t *Cache) StoreWithoutRaw(id string, tpl *Template, err error) {
entry := parsedTemplate{
template: tpl,
err: err,
raw: "",
}
_ = t.items.Set(id, entry)
}
// Get returns only the template without raw bytes
func (t *Cache) Get(id string) (*Template, error) {
value, ok := t.items.Get(id)
if !ok {
return nil, nil
}
return value.template, value.err
}
// Purge the cache

View File

@ -404,7 +404,7 @@ func (template *Template) ImportFileRefs(options *protocols.ExecutorOptions) err
// for code protocol requests
for _, request := range template.RequestsCode {
// simple test to check if source is a file or a snippet
if len(strings.Split(request.Source, "\n")) == 1 && fileutil.FileExists(request.Source) {
if !strings.ContainsRune(request.Source, '\n') && fileutil.FileExists(request.Source) {
if val, ok := loadFile(request.Source); ok {
template.ImportedFiles = append(template.ImportedFiles, request.Source)
request.Source = val
@ -415,7 +415,7 @@ func (template *Template) ImportFileRefs(options *protocols.ExecutorOptions) err
// for javascript protocol code references
for _, request := range template.RequestsJavascript {
// simple test to check if source is a file or a snippet
if len(strings.Split(request.Code, "\n")) == 1 && fileutil.FileExists(request.Code) {
if !strings.ContainsRune(request.Code, '\n') && fileutil.FileExists(request.Code) {
if val, ok := loadFile(request.Code); ok {
template.ImportedFiles = append(template.ImportedFiles, request.Code)
request.Code = val
@ -442,7 +442,7 @@ func (template *Template) ImportFileRefs(options *protocols.ExecutorOptions) err
if req.Type() == types.CodeProtocol {
request := req.(*code.Request)
// simple test to check if source is a file or a snippet
if len(strings.Split(request.Source, "\n")) == 1 && fileutil.FileExists(request.Source) {
if !strings.ContainsRune(request.Source, '\n') && fileutil.FileExists(request.Source) {
if val, ok := loadFile(request.Source); ok {
template.ImportedFiles = append(template.ImportedFiles, request.Source)
request.Source = val
@ -456,7 +456,7 @@ func (template *Template) ImportFileRefs(options *protocols.ExecutorOptions) err
if req.Type() == types.JavascriptProtocol {
request := req.(*javascript.Request)
// simple test to check if source is a file or a snippet
if len(strings.Split(request.Code, "\n")) == 1 && fileutil.FileExists(request.Code) {
if !strings.ContainsRune(request.Code, '\n') && fileutil.FileExists(request.Code) {
if val, ok := loadFile(request.Code); ok {
template.ImportedFiles = append(template.ImportedFiles, request.Code)
request.Code = val

View File

@ -9,6 +9,32 @@ import (
"gopkg.in/yaml.v2"
)
func TestCachePoolZeroing(t *testing.T) {
c := NewCache()
tpl := &Template{ID: "x"}
raw := []byte("SOME BIG RAW")
c.Store("id1", tpl, raw, nil)
gotTpl, gotErr := c.Get("id1")
if gotErr != nil {
t.Fatalf("unexpected err: %v", gotErr)
}
if gotTpl == nil || gotTpl.ID != "x" {
t.Fatalf("unexpected tpl: %#v", gotTpl)
}
// StoreWithoutRaw should not retain raw
c.StoreWithoutRaw("id2", tpl, nil)
gotTpl2, gotErr2 := c.Get("id2")
if gotErr2 != nil {
t.Fatalf("unexpected err: %v", gotErr2)
}
if gotTpl2 == nil || gotTpl2.ID != "x" {
t.Fatalf("unexpected tpl2: %#v", gotTpl2)
}
}
func TestTemplateStruct(t *testing.T) {
templatePath := "./tests/match-1.yaml"
bin, err := os.ReadFile(templatePath)