mirror of
https://github.com/projectdiscovery/nuclei.git
synced 2025-12-17 15:45:27 +00:00
cache, goroutine and unbounded workers management (#6420)
* Enhance matcher compilation with caching for regex and DSL expressions to improve performance. Update template parsing to conditionally retain raw templates based on size constraints. * Implement caching for regex and DSL expressions in extractors and matchers to enhance performance. Introduce a buffer pool in raw requests to reduce memory allocations. Update template cache management for improved efficiency. * feat: improve concurrency to be bound * refactor: replace fmt.Sprintf with fmt.Fprintf for improved performance in header handling * feat: add regex matching tests and benchmarks for performance evaluation * feat: add prefix check in regex extraction to optimize matching process * feat: implement regex caching mechanism to enhance performance in extractors and matchers, along with tests and benchmarks for validation * feat: add unit tests for template execution in the core engine, enhancing test coverage and reliability * feat: enhance error handling in template execution and improve regex caching logic for better performance * Implement caching for regex and DSL expressions in the cache package, replacing previous sync.Map usage. Add unit tests for cache functionality, including eviction by capacity and retrieval of cached items. Update extractors and matchers to utilize the new cache system for improved performance and memory efficiency. * Add tests for SetCapacities in cache package to ensure cache behavior on capacity changes - Implemented TestSetCapacities_NoRebuildOnZero to verify that setting capacities to zero does not clear existing caches. - Added TestSetCapacities_BeforeFirstUse to confirm that initial cache settings are respected and not overridden by subsequent capacity changes. * Refactor matchers and update load test generator to use io package - Removed maxRegexScanBytes constant from match.go. - Replaced ioutil with io package in load_test.go for NopCloser usage. - Restored TestValidate_AllowsInlineMultiline in load_test.go to ensure inline validation functionality. * Add cancellation support in template execution and enhance test coverage - Updated executeTemplateWithTargets to respect context cancellation. - Introduced fakeTargetProvider and slowExecuter for testing. - Added Test_executeTemplateWithTargets_RespectsCancellation to validate cancellation behavior during template execution.
This commit is contained in:
parent
d4f1a815ed
commit
c4fa2c74c1
@ -48,8 +48,15 @@ func (e *Engine) executeAllSelfContained(ctx context.Context, alltemplates []*te
|
||||
|
||||
// executeTemplateWithTargets executes a given template on x targets (with a internal targetpool(i.e concurrency))
|
||||
func (e *Engine) executeTemplateWithTargets(ctx context.Context, template *templates.Template, target provider.InputProvider, results *atomic.Bool) {
|
||||
// this is target pool i.e max target to execute
|
||||
wg := e.workPool.InputPool(template.Type())
|
||||
if e.workPool == nil {
|
||||
e.workPool = e.GetWorkPool()
|
||||
}
|
||||
// Bounded worker pool using input concurrency
|
||||
pool := e.workPool.InputPool(template.Type())
|
||||
workerCount := 1
|
||||
if pool != nil && pool.Size > 0 {
|
||||
workerCount = pool.Size
|
||||
}
|
||||
|
||||
var (
|
||||
index uint32
|
||||
@ -78,6 +85,41 @@ func (e *Engine) executeTemplateWithTargets(ctx context.Context, template *templ
|
||||
currentInfo.Unlock()
|
||||
}
|
||||
|
||||
// task represents a single target execution unit
|
||||
type task struct {
|
||||
index uint32
|
||||
skip bool
|
||||
value *contextargs.MetaInput
|
||||
}
|
||||
|
||||
tasks := make(chan task)
|
||||
var workersWg sync.WaitGroup
|
||||
workersWg.Add(workerCount)
|
||||
for i := 0; i < workerCount; i++ {
|
||||
go func() {
|
||||
defer workersWg.Done()
|
||||
for t := range tasks {
|
||||
func() {
|
||||
defer cleanupInFlight(t.index)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
if t.skip {
|
||||
return
|
||||
}
|
||||
|
||||
match, err := e.executeTemplateOnInput(ctx, template, t.value)
|
||||
if err != nil {
|
||||
e.options.Logger.Warning().Msgf("[%s] Could not execute step on %s: %s\n", e.executerOpts.Colorizer.BrightBlue(template.ID), t.value.Input, err)
|
||||
}
|
||||
results.CompareAndSwap(false, match)
|
||||
}()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
target.Iterate(func(scannedValue *contextargs.MetaInput) bool {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@ -128,43 +170,13 @@ func (e *Engine) executeTemplateWithTargets(ctx context.Context, template *templ
|
||||
return true
|
||||
}
|
||||
|
||||
wg.Add()
|
||||
go func(index uint32, skip bool, value *contextargs.MetaInput) {
|
||||
defer wg.Done()
|
||||
defer cleanupInFlight(index)
|
||||
if skip {
|
||||
return
|
||||
}
|
||||
|
||||
var match bool
|
||||
var err error
|
||||
ctxArgs := contextargs.New(ctx)
|
||||
ctxArgs.MetaInput = value
|
||||
ctx := scan.NewScanContext(ctx, ctxArgs)
|
||||
switch template.Type() {
|
||||
case types.WorkflowProtocol:
|
||||
match = e.executeWorkflow(ctx, template.CompiledWorkflow)
|
||||
default:
|
||||
if e.Callback != nil {
|
||||
if results, err := template.Executer.ExecuteWithResults(ctx); err == nil {
|
||||
for _, result := range results {
|
||||
e.Callback(result)
|
||||
}
|
||||
}
|
||||
match = true
|
||||
} else {
|
||||
match, err = template.Executer.Execute(ctx)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
e.options.Logger.Warning().Msgf("[%s] Could not execute step on %s: %s\n", e.executerOpts.Colorizer.BrightBlue(template.ID), value.Input, err)
|
||||
}
|
||||
results.CompareAndSwap(false, match)
|
||||
}(index, skip, scannedValue)
|
||||
tasks <- task{index: index, skip: skip, value: scannedValue}
|
||||
index++
|
||||
return true
|
||||
})
|
||||
wg.Wait()
|
||||
|
||||
close(tasks)
|
||||
workersWg.Wait()
|
||||
|
||||
// on completion marks the template as completed
|
||||
currentInfo.Lock()
|
||||
@ -202,26 +214,7 @@ func (e *Engine) executeTemplatesOnTarget(ctx context.Context, alltemplates []*t
|
||||
go func(template *templates.Template, value *contextargs.MetaInput, wg *syncutil.AdaptiveWaitGroup) {
|
||||
defer wg.Done()
|
||||
|
||||
var match bool
|
||||
var err error
|
||||
ctxArgs := contextargs.New(ctx)
|
||||
ctxArgs.MetaInput = value
|
||||
ctx := scan.NewScanContext(ctx, ctxArgs)
|
||||
switch template.Type() {
|
||||
case types.WorkflowProtocol:
|
||||
match = e.executeWorkflow(ctx, template.CompiledWorkflow)
|
||||
default:
|
||||
if e.Callback != nil {
|
||||
if results, err := template.Executer.ExecuteWithResults(ctx); err == nil {
|
||||
for _, result := range results {
|
||||
e.Callback(result)
|
||||
}
|
||||
}
|
||||
match = true
|
||||
} else {
|
||||
match, err = template.Executer.Execute(ctx)
|
||||
}
|
||||
}
|
||||
match, err := e.executeTemplateOnInput(ctx, template, value)
|
||||
if err != nil {
|
||||
e.options.Logger.Warning().Msgf("[%s] Could not execute step on %s: %s\n", e.executerOpts.Colorizer.BrightBlue(template.ID), value.Input, err)
|
||||
}
|
||||
@ -229,3 +222,27 @@ func (e *Engine) executeTemplatesOnTarget(ctx context.Context, alltemplates []*t
|
||||
}(tpl, target, sg)
|
||||
}
|
||||
}
|
||||
|
||||
// executeTemplateOnInput performs template execution for a single input and returns match status and error
|
||||
func (e *Engine) executeTemplateOnInput(ctx context.Context, template *templates.Template, value *contextargs.MetaInput) (bool, error) {
|
||||
ctxArgs := contextargs.New(ctx)
|
||||
ctxArgs.MetaInput = value
|
||||
scanCtx := scan.NewScanContext(ctx, ctxArgs)
|
||||
|
||||
switch template.Type() {
|
||||
case types.WorkflowProtocol:
|
||||
return e.executeWorkflow(scanCtx, template.CompiledWorkflow), nil
|
||||
default:
|
||||
if e.Callback != nil {
|
||||
results, err := template.Executer.ExecuteWithResults(scanCtx)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
for _, result := range results {
|
||||
e.Callback(result)
|
||||
}
|
||||
return len(results) > 0, nil
|
||||
}
|
||||
return template.Executer.Execute(scanCtx)
|
||||
}
|
||||
}
|
||||
|
||||
148
pkg/core/executors_test.go
Normal file
148
pkg/core/executors_test.go
Normal file
@ -0,0 +1,148 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
inputtypes "github.com/projectdiscovery/nuclei/v3/pkg/input/types"
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/output"
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/protocols"
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/contextargs"
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/scan"
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/templates"
|
||||
tmpltypes "github.com/projectdiscovery/nuclei/v3/pkg/templates/types"
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/types"
|
||||
)
|
||||
|
||||
// fakeExecuter is a simple stub for protocols.Executer used to test executeTemplateOnInput
|
||||
type fakeExecuter struct {
|
||||
withResults bool
|
||||
}
|
||||
|
||||
func (f *fakeExecuter) Compile() error { return nil }
|
||||
func (f *fakeExecuter) Requests() int { return 1 }
|
||||
func (f *fakeExecuter) Execute(ctx *scan.ScanContext) (bool, error) { return !f.withResults, nil }
|
||||
func (f *fakeExecuter) ExecuteWithResults(ctx *scan.ScanContext) ([]*output.ResultEvent, error) {
|
||||
if !f.withResults {
|
||||
return nil, nil
|
||||
}
|
||||
return []*output.ResultEvent{{Host: "h"}}, nil
|
||||
}
|
||||
|
||||
// newTestEngine creates a minimal Engine for tests
|
||||
func newTestEngine() *Engine {
|
||||
return New(&types.Options{})
|
||||
}
|
||||
|
||||
func Test_executeTemplateOnInput_CallbackPath(t *testing.T) {
|
||||
e := newTestEngine()
|
||||
called := 0
|
||||
e.Callback = func(*output.ResultEvent) { called++ }
|
||||
|
||||
tpl := &templates.Template{}
|
||||
tpl.Executer = &fakeExecuter{withResults: true}
|
||||
|
||||
ok, err := e.executeTemplateOnInput(context.Background(), tpl, &contextargs.MetaInput{Input: "x"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !ok {
|
||||
t.Fatalf("expected match true")
|
||||
}
|
||||
if called == 0 {
|
||||
t.Fatalf("expected callback to be called")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_executeTemplateOnInput_ExecutePath(t *testing.T) {
|
||||
e := newTestEngine()
|
||||
tpl := &templates.Template{}
|
||||
tpl.Executer = &fakeExecuter{withResults: false}
|
||||
|
||||
ok, err := e.executeTemplateOnInput(context.Background(), tpl, &contextargs.MetaInput{Input: "x"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !ok {
|
||||
t.Fatalf("expected match true from Execute path")
|
||||
}
|
||||
}
|
||||
|
||||
type fakeExecuterErr struct{}
|
||||
|
||||
func (f *fakeExecuterErr) Compile() error { return nil }
|
||||
func (f *fakeExecuterErr) Requests() int { return 1 }
|
||||
func (f *fakeExecuterErr) Execute(ctx *scan.ScanContext) (bool, error) { return false, nil }
|
||||
func (f *fakeExecuterErr) ExecuteWithResults(ctx *scan.ScanContext) ([]*output.ResultEvent, error) {
|
||||
return nil, fmt.Errorf("boom")
|
||||
}
|
||||
|
||||
func Test_executeTemplateOnInput_CallbackErrorPropagates(t *testing.T) {
|
||||
e := newTestEngine()
|
||||
e.Callback = func(*output.ResultEvent) {}
|
||||
tpl := &templates.Template{}
|
||||
tpl.Executer = &fakeExecuterErr{}
|
||||
|
||||
ok, err := e.executeTemplateOnInput(context.Background(), tpl, &contextargs.MetaInput{Input: "x"})
|
||||
if err == nil {
|
||||
t.Fatalf("expected error to propagate")
|
||||
}
|
||||
if ok {
|
||||
t.Fatalf("expected match to be false on error")
|
||||
}
|
||||
}
|
||||
|
||||
type fakeTargetProvider struct {
|
||||
values []*contextargs.MetaInput
|
||||
}
|
||||
|
||||
func (f *fakeTargetProvider) Count() int64 { return int64(len(f.values)) }
|
||||
func (f *fakeTargetProvider) Iterate(cb func(value *contextargs.MetaInput) bool) {
|
||||
for _, v := range f.values {
|
||||
if !cb(v) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
func (f *fakeTargetProvider) Set(string, string) {}
|
||||
func (f *fakeTargetProvider) SetWithProbe(string, string, inputtypes.InputLivenessProbe) error {
|
||||
return nil
|
||||
}
|
||||
func (f *fakeTargetProvider) SetWithExclusions(string, string) error { return nil }
|
||||
func (f *fakeTargetProvider) InputType() string { return "test" }
|
||||
func (f *fakeTargetProvider) Close() {}
|
||||
|
||||
type slowExecuter struct{}
|
||||
|
||||
func (s *slowExecuter) Compile() error { return nil }
|
||||
func (s *slowExecuter) Requests() int { return 1 }
|
||||
func (s *slowExecuter) Execute(ctx *scan.ScanContext) (bool, error) {
|
||||
select {
|
||||
case <-ctx.Context().Done():
|
||||
return false, ctx.Context().Err()
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
func (s *slowExecuter) ExecuteWithResults(ctx *scan.ScanContext) ([]*output.ResultEvent, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func Test_executeTemplateWithTargets_RespectsCancellation(t *testing.T) {
|
||||
e := newTestEngine()
|
||||
e.SetExecuterOptions(&protocols.ExecutorOptions{Logger: e.Logger, ResumeCfg: types.NewResumeCfg(), ProtocolType: tmpltypes.HTTPProtocol})
|
||||
|
||||
tpl := &templates.Template{}
|
||||
tpl.Executer = &slowExecuter{}
|
||||
|
||||
targets := &fakeTargetProvider{values: []*contextargs.MetaInput{{Input: "a"}, {Input: "b"}, {Input: "c"}}}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
var matched atomic.Bool
|
||||
e.executeTemplateWithTargets(ctx, tpl, targets, &matched)
|
||||
}
|
||||
62
pkg/operators/cache/cache.go
vendored
Normal file
62
pkg/operators/cache/cache.go
vendored
Normal file
@ -0,0 +1,62 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"sync"
|
||||
|
||||
"github.com/Knetic/govaluate"
|
||||
"github.com/projectdiscovery/gcache"
|
||||
)
|
||||
|
||||
var (
|
||||
initOnce sync.Once
|
||||
mu sync.RWMutex
|
||||
|
||||
regexCap = 4096
|
||||
dslCap = 4096
|
||||
|
||||
regexCache gcache.Cache[string, *regexp.Regexp]
|
||||
dslCache gcache.Cache[string, *govaluate.EvaluableExpression]
|
||||
)
|
||||
|
||||
func initCaches() {
|
||||
initOnce.Do(func() {
|
||||
regexCache = gcache.New[string, *regexp.Regexp](regexCap).LRU().Build()
|
||||
dslCache = gcache.New[string, *govaluate.EvaluableExpression](dslCap).LRU().Build()
|
||||
})
|
||||
}
|
||||
|
||||
func SetCapacities(regexCapacity, dslCapacity int) {
|
||||
// ensure caches are initialized under initOnce, so later Regex()/DSL() won't re-init
|
||||
initCaches()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
if regexCapacity > 0 {
|
||||
regexCap = regexCapacity
|
||||
}
|
||||
if dslCapacity > 0 {
|
||||
dslCap = dslCapacity
|
||||
}
|
||||
if regexCapacity <= 0 && dslCapacity <= 0 {
|
||||
return
|
||||
}
|
||||
// rebuild caches with new capacities
|
||||
regexCache = gcache.New[string, *regexp.Regexp](regexCap).LRU().Build()
|
||||
dslCache = gcache.New[string, *govaluate.EvaluableExpression](dslCap).LRU().Build()
|
||||
}
|
||||
|
||||
func Regex() gcache.Cache[string, *regexp.Regexp] {
|
||||
initCaches()
|
||||
mu.RLock()
|
||||
defer mu.RUnlock()
|
||||
return regexCache
|
||||
}
|
||||
|
||||
func DSL() gcache.Cache[string, *govaluate.EvaluableExpression] {
|
||||
initCaches()
|
||||
mu.RLock()
|
||||
defer mu.RUnlock()
|
||||
return dslCache
|
||||
}
|
||||
114
pkg/operators/cache/cache_test.go
vendored
Normal file
114
pkg/operators/cache/cache_test.go
vendored
Normal file
@ -0,0 +1,114 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"testing"
|
||||
|
||||
"github.com/Knetic/govaluate"
|
||||
)
|
||||
|
||||
func TestRegexCache_SetGet(t *testing.T) {
|
||||
// ensure init
|
||||
c := Regex()
|
||||
pattern := "abc(\n)?123"
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
t.Fatalf("compile: %v", err)
|
||||
}
|
||||
if err := c.Set(pattern, re); err != nil {
|
||||
t.Fatalf("set: %v", err)
|
||||
}
|
||||
got, err := c.GetIFPresent(pattern)
|
||||
if err != nil || got == nil {
|
||||
t.Fatalf("get: %v got=%v", err, got)
|
||||
}
|
||||
if got.String() != re.String() {
|
||||
t.Fatalf("mismatch: %s != %s", got.String(), re.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDSLCache_SetGet(t *testing.T) {
|
||||
c := DSL()
|
||||
expr := "1 + 2 == 3"
|
||||
ast, err := govaluate.NewEvaluableExpression(expr)
|
||||
if err != nil {
|
||||
t.Fatalf("dsl compile: %v", err)
|
||||
}
|
||||
if err := c.Set(expr, ast); err != nil {
|
||||
t.Fatalf("set: %v", err)
|
||||
}
|
||||
got, err := c.GetIFPresent(expr)
|
||||
if err != nil || got == nil {
|
||||
t.Fatalf("get: %v got=%v", err, got)
|
||||
}
|
||||
if got.String() != ast.String() {
|
||||
t.Fatalf("mismatch: %s != %s", got.String(), ast.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegexCache_EvictionByCapacity(t *testing.T) {
|
||||
SetCapacities(3, 3)
|
||||
c := Regex()
|
||||
for i := 0; i < 5; i++ {
|
||||
k := string(rune('a' + i))
|
||||
re := regexp.MustCompile(k)
|
||||
_ = c.Set(k, re)
|
||||
}
|
||||
// last 3 keys expected to remain under LRU: 'c','d','e'
|
||||
if _, err := c.GetIFPresent("a"); err == nil {
|
||||
t.Fatalf("expected 'a' to be evicted")
|
||||
}
|
||||
if _, err := c.GetIFPresent("b"); err == nil {
|
||||
t.Fatalf("expected 'b' to be evicted")
|
||||
}
|
||||
if _, err := c.GetIFPresent("c"); err != nil {
|
||||
t.Fatalf("expected 'c' present")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetCapacities_NoRebuildOnZero(t *testing.T) {
|
||||
// init
|
||||
SetCapacities(4, 4)
|
||||
c1 := Regex()
|
||||
_ = c1.Set("k", regexp.MustCompile("k"))
|
||||
if _, err := c1.GetIFPresent("k"); err != nil {
|
||||
t.Fatalf("expected key present: %v", err)
|
||||
}
|
||||
// zero changes should not rebuild/clear caches
|
||||
SetCapacities(0, 0)
|
||||
c2 := Regex()
|
||||
if _, err := c2.GetIFPresent("k"); err != nil {
|
||||
t.Fatalf("key lost after zero-capacity SetCapacities: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetCapacities_BeforeFirstUse(t *testing.T) {
|
||||
// This should not be overridden by later initCaches
|
||||
SetCapacities(2, 0)
|
||||
c := Regex()
|
||||
_ = c.Set("a", regexp.MustCompile("a"))
|
||||
_ = c.Set("b", regexp.MustCompile("b"))
|
||||
_ = c.Set("c", regexp.MustCompile("c"))
|
||||
if _, err := c.GetIFPresent("a"); err == nil {
|
||||
t.Fatalf("expected 'a' to be evicted under cap=2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetCapacities_ConcurrentAccess(t *testing.T) {
|
||||
SetCapacities(64, 64)
|
||||
stop := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
for i := 0; i < 5000; i++ {
|
||||
_ = Regex().Set("k"+string(rune('a'+(i%26))), regexp.MustCompile("a"))
|
||||
_, _ = Regex().GetIFPresent("k" + string(rune('a'+(i%26))))
|
||||
_, _ = DSL().GetIFPresent("1+2==3")
|
||||
}
|
||||
close(stop)
|
||||
}()
|
||||
|
||||
for i := 0; i < 200; i++ {
|
||||
SetCapacities(64+(i%5), 64+((i+1)%5))
|
||||
}
|
||||
<-stop
|
||||
}
|
||||
@ -7,6 +7,7 @@ import (
|
||||
|
||||
"github.com/Knetic/govaluate"
|
||||
"github.com/itchyny/gojq"
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/operators/cache"
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/operators/common/dsl"
|
||||
)
|
||||
|
||||
@ -20,10 +21,15 @@ func (e *Extractor) CompileExtractors() error {
|
||||
e.extractorType = computedType
|
||||
// Compile the regexes
|
||||
for _, regex := range e.Regex {
|
||||
if cached, err := cache.Regex().GetIFPresent(regex); err == nil && cached != nil {
|
||||
e.regexCompiled = append(e.regexCompiled, cached)
|
||||
continue
|
||||
}
|
||||
compiled, err := regexp.Compile(regex)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not compile regex: %s", regex)
|
||||
}
|
||||
_ = cache.Regex().Set(regex, compiled)
|
||||
e.regexCompiled = append(e.regexCompiled, compiled)
|
||||
}
|
||||
for i, kval := range e.KVal {
|
||||
@ -43,10 +49,15 @@ func (e *Extractor) CompileExtractors() error {
|
||||
}
|
||||
|
||||
for _, dslExp := range e.DSL {
|
||||
if cached, err := cache.DSL().GetIFPresent(dslExp); err == nil && cached != nil {
|
||||
e.dslCompiled = append(e.dslCompiled, cached)
|
||||
continue
|
||||
}
|
||||
compiled, err := govaluate.NewEvaluableExpressionWithFunctions(dslExp, dsl.HelperFunctions)
|
||||
if err != nil {
|
||||
return &dsl.CompilationError{DslSignature: dslExp, WrappedError: err}
|
||||
}
|
||||
_ = cache.DSL().Set(dslExp, compiled)
|
||||
e.dslCompiled = append(e.dslCompiled, compiled)
|
||||
}
|
||||
|
||||
|
||||
@ -17,9 +17,19 @@ func (e *Extractor) ExtractRegex(corpus string) map[string]struct{} {
|
||||
|
||||
groupPlusOne := e.RegexGroup + 1
|
||||
for _, regex := range e.regexCompiled {
|
||||
matches := regex.FindAllStringSubmatch(corpus, -1)
|
||||
// skip prefix short-circuit for case-insensitive patterns
|
||||
rstr := regex.String()
|
||||
if !strings.Contains(rstr, "(?i") {
|
||||
if prefix, ok := regex.LiteralPrefix(); ok && prefix != "" {
|
||||
if !strings.Contains(corpus, prefix) {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, match := range matches {
|
||||
submatches := regex.FindAllStringSubmatch(corpus, -1)
|
||||
|
||||
for _, match := range submatches {
|
||||
if len(match) < groupPlusOne {
|
||||
continue
|
||||
}
|
||||
|
||||
@ -7,7 +7,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/Knetic/govaluate"
|
||||
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/operators/cache"
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/operators/common/dsl"
|
||||
)
|
||||
|
||||
@ -42,12 +42,17 @@ func (matcher *Matcher) CompileMatchers() error {
|
||||
matcher.Part = "body"
|
||||
}
|
||||
|
||||
// Compile the regexes
|
||||
// Compile the regexes (with shared cache)
|
||||
for _, regex := range matcher.Regex {
|
||||
if cached, err := cache.Regex().GetIFPresent(regex); err == nil && cached != nil {
|
||||
matcher.regexCompiled = append(matcher.regexCompiled, cached)
|
||||
continue
|
||||
}
|
||||
compiled, err := regexp.Compile(regex)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not compile regex: %s", regex)
|
||||
}
|
||||
_ = cache.Regex().Set(regex, compiled)
|
||||
matcher.regexCompiled = append(matcher.regexCompiled, compiled)
|
||||
}
|
||||
|
||||
@ -60,12 +65,17 @@ func (matcher *Matcher) CompileMatchers() error {
|
||||
}
|
||||
}
|
||||
|
||||
// Compile the dsl expressions
|
||||
// Compile the dsl expressions (with shared cache)
|
||||
for _, dslExpression := range matcher.DSL {
|
||||
if cached, err := cache.DSL().GetIFPresent(dslExpression); err == nil && cached != nil {
|
||||
matcher.dslCompiled = append(matcher.dslCompiled, cached)
|
||||
continue
|
||||
}
|
||||
compiledExpression, err := govaluate.NewEvaluableExpressionWithFunctions(dslExpression, dsl.HelperFunctions)
|
||||
if err != nil {
|
||||
return &dsl.CompilationError{DslSignature: dslExpression, WrappedError: err}
|
||||
}
|
||||
_ = cache.DSL().Set(dslExpression, compiledExpression)
|
||||
matcher.dslCompiled = append(matcher.dslCompiled, compiledExpression)
|
||||
}
|
||||
|
||||
|
||||
@ -106,10 +106,33 @@ func (matcher *Matcher) MatchRegex(corpus string) (bool, []string) {
|
||||
var matchedRegexes []string
|
||||
// Iterate over all the regexes accepted as valid
|
||||
for i, regex := range matcher.regexCompiled {
|
||||
// Continue if the regex doesn't match
|
||||
if !regex.MatchString(corpus) {
|
||||
// If we are in an AND request and a match failed,
|
||||
// return false as the AND condition fails on any single mismatch.
|
||||
// Literal prefix short-circuit
|
||||
rstr := regex.String()
|
||||
if !strings.Contains(rstr, "(?i") { // covers (?i) and (?i:
|
||||
if prefix, ok := regex.LiteralPrefix(); ok && prefix != "" {
|
||||
if !strings.Contains(corpus, prefix) {
|
||||
switch matcher.condition {
|
||||
case ANDCondition:
|
||||
return false, []string{}
|
||||
case ORCondition:
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fast OR-path: return first match without full scan
|
||||
if matcher.condition == ORCondition && !matcher.MatchAll {
|
||||
m := regex.FindAllString(corpus, 1)
|
||||
if len(m) == 0 {
|
||||
continue
|
||||
}
|
||||
return true, m
|
||||
}
|
||||
|
||||
// Single scan: get all matches directly
|
||||
currentMatches := regex.FindAllString(corpus, -1)
|
||||
if len(currentMatches) == 0 {
|
||||
switch matcher.condition {
|
||||
case ANDCondition:
|
||||
return false, []string{}
|
||||
@ -118,12 +141,7 @@ func (matcher *Matcher) MatchRegex(corpus string) (bool, []string) {
|
||||
}
|
||||
}
|
||||
|
||||
currentMatches := regex.FindAllString(corpus, -1)
|
||||
// If the condition was an OR, return on the first match.
|
||||
if matcher.condition == ORCondition && !matcher.MatchAll {
|
||||
return true, currentMatches
|
||||
}
|
||||
|
||||
// If the condition was an OR (and MatchAll true), we still need to gather all
|
||||
matchedRegexes = append(matchedRegexes, currentMatches...)
|
||||
|
||||
// If we are at the end of the regex, return with true
|
||||
|
||||
@ -84,7 +84,7 @@ func TestMatcher_MatchDSL(t *testing.T) {
|
||||
|
||||
values := []string{"PING", "pong"}
|
||||
|
||||
for value := range values {
|
||||
for _, value := range values {
|
||||
isMatched := m.MatchDSL(map[string]interface{}{"body": value, "VARIABLE": value})
|
||||
require.True(t, isMatched)
|
||||
}
|
||||
@ -209,3 +209,66 @@ func TestMatcher_MatchXPath_XML(t *testing.T) {
|
||||
isMatched = m.MatchXPath("<h1> not right <q id=2/>notvalid")
|
||||
require.False(t, isMatched, "Invalid xpath did not return false")
|
||||
}
|
||||
|
||||
func TestMatchRegex_CaseInsensitivePrefixSkip(t *testing.T) {
|
||||
m := &Matcher{Type: MatcherTypeHolder{MatcherType: RegexMatcher}, Condition: "or", Regex: []string{"(?i)abc"}}
|
||||
err := m.CompileMatchers()
|
||||
require.NoError(t, err)
|
||||
ok, got := m.MatchRegex("zzz AbC yyy")
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, got)
|
||||
}
|
||||
|
||||
func TestMatchStatusCodeAndSize(t *testing.T) {
|
||||
mStatus := &Matcher{Status: []int{200, 302}}
|
||||
require.True(t, mStatus.MatchStatusCode(200))
|
||||
require.True(t, mStatus.MatchStatusCode(302))
|
||||
require.False(t, mStatus.MatchStatusCode(404))
|
||||
|
||||
mSize := &Matcher{Size: []int{5, 10}}
|
||||
require.True(t, mSize.MatchSize(5))
|
||||
require.False(t, mSize.MatchSize(7))
|
||||
}
|
||||
|
||||
func TestMatchBinary_AND_OR(t *testing.T) {
|
||||
// AND should fail if any binary not present
|
||||
mAnd := &Matcher{Type: MatcherTypeHolder{MatcherType: BinaryMatcher}, Condition: "and", Binary: []string{"50494e47", "414141"}} // "PING", "AAA"
|
||||
require.NoError(t, mAnd.CompileMatchers())
|
||||
ok, _ := mAnd.MatchBinary("PING")
|
||||
require.False(t, ok)
|
||||
// OR should succeed if any present
|
||||
mOr := &Matcher{Type: MatcherTypeHolder{MatcherType: BinaryMatcher}, Condition: "or", Binary: []string{"414141", "50494e47"}} // "AAA", "PING"
|
||||
require.NoError(t, mOr.CompileMatchers())
|
||||
ok, got := mOr.MatchBinary("xxPINGyy")
|
||||
require.True(t, ok)
|
||||
require.NotEmpty(t, got)
|
||||
}
|
||||
|
||||
func TestMatchRegex_LiteralPrefixShortCircuit(t *testing.T) {
|
||||
// AND: first regex has literal prefix "abc"; corpus lacks it => early false
|
||||
mAnd := &Matcher{Type: MatcherTypeHolder{MatcherType: RegexMatcher}, Condition: "and", Regex: []string{"abc[0-9]*", "[0-9]{2}"}}
|
||||
require.NoError(t, mAnd.CompileMatchers())
|
||||
ok, matches := mAnd.MatchRegex("zzz 12 yyy")
|
||||
require.False(t, ok)
|
||||
require.Empty(t, matches)
|
||||
|
||||
// OR: first regex skipped due to missing prefix, second matches => true
|
||||
mOr := &Matcher{Type: MatcherTypeHolder{MatcherType: RegexMatcher}, Condition: "or", Regex: []string{"abc[0-9]*", "[0-9]{2}"}}
|
||||
require.NoError(t, mOr.CompileMatchers())
|
||||
ok, matches = mOr.MatchRegex("zzz 12 yyy")
|
||||
require.True(t, ok)
|
||||
require.Equal(t, []string{"12"}, matches)
|
||||
}
|
||||
|
||||
func TestMatcher_MatchDSL_ErrorHandling(t *testing.T) {
|
||||
// First expression errors (division by zero), second is true
|
||||
bad, err := govaluate.NewEvaluableExpression("1 / 0")
|
||||
require.NoError(t, err)
|
||||
good, err := govaluate.NewEvaluableExpression("1 == 1")
|
||||
require.NoError(t, err)
|
||||
|
||||
m := &Matcher{Type: MatcherTypeHolder{MatcherType: DSLMatcher}, Condition: "or", dslCompiled: []*govaluate.EvaluableExpression{bad, good}}
|
||||
require.NoError(t, m.CompileMatchers())
|
||||
ok := m.MatchDSL(map[string]interface{}{})
|
||||
require.True(t, ok)
|
||||
}
|
||||
|
||||
26
pkg/protocols/common/generators/attack_types_test.go
Normal file
26
pkg/protocols/common/generators/attack_types_test.go
Normal file
@ -0,0 +1,26 @@
|
||||
package generators
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestAttackTypeHelpers(t *testing.T) {
|
||||
// GetSupportedAttackTypes should include three values
|
||||
types := GetSupportedAttackTypes()
|
||||
if len(types) != 3 {
|
||||
t.Fatalf("expected 3 types, got %d", len(types))
|
||||
}
|
||||
// toAttackType valid
|
||||
if got, err := toAttackType("pitchfork"); err != nil || got != PitchForkAttack {
|
||||
t.Fatalf("toAttackType failed: %v %v", got, err)
|
||||
}
|
||||
// toAttackType invalid
|
||||
if _, err := toAttackType("nope"); err == nil {
|
||||
t.Fatalf("expected error for invalid attack type")
|
||||
}
|
||||
// normalizeValue and String
|
||||
if normalizeValue(" ClusterBomb ") != "clusterbomb" {
|
||||
t.Fatalf("normalizeValue failed")
|
||||
}
|
||||
if ClusterBombAttack.String() != "clusterbomb" {
|
||||
t.Fatalf("String failed")
|
||||
}
|
||||
}
|
||||
38
pkg/protocols/common/generators/env_test.go
Normal file
38
pkg/protocols/common/generators/env_test.go
Normal file
@ -0,0 +1,38 @@
|
||||
package generators
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseEnvVars(t *testing.T) {
|
||||
old := os.Environ()
|
||||
// set a scoped env var
|
||||
_ = os.Setenv("NUCLEI_TEST_K", "V1")
|
||||
t.Cleanup(func() {
|
||||
// restore
|
||||
for _, kv := range old {
|
||||
parts := kv
|
||||
_ = parts // nothing, environment already has superset; best-effort cleanup below
|
||||
}
|
||||
_ = os.Unsetenv("NUCLEI_TEST_K")
|
||||
})
|
||||
vars := parseEnvVars()
|
||||
if vars["NUCLEI_TEST_K"] != "V1" {
|
||||
t.Fatalf("expected V1, got %v", vars["NUCLEI_TEST_K"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvVarsMemoization(t *testing.T) {
|
||||
// reset memoized map
|
||||
envVars = nil
|
||||
_ = os.Setenv("NUCLEI_TEST_MEMO", "A")
|
||||
t.Cleanup(func() { _ = os.Unsetenv("NUCLEI_TEST_MEMO") })
|
||||
v1 := EnvVars()["NUCLEI_TEST_MEMO"]
|
||||
// change env after memoization
|
||||
_ = os.Setenv("NUCLEI_TEST_MEMO", "B")
|
||||
v2 := EnvVars()["NUCLEI_TEST_MEMO"]
|
||||
if v1 != "A" || v2 != "A" {
|
||||
t.Fatalf("memoization failed: %v %v", v1, v2)
|
||||
}
|
||||
}
|
||||
@ -17,6 +17,20 @@ func (generator *PayloadGenerator) loadPayloads(payloads map[string]interface{},
|
||||
for name, payload := range payloads {
|
||||
switch pt := payload.(type) {
|
||||
case string:
|
||||
// Fast path: if no newline, treat as file path
|
||||
if !strings.ContainsRune(pt, '\n') {
|
||||
file, err := generator.options.LoadHelperFile(pt, templatePath, generator.catalog)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "could not load payload file")
|
||||
}
|
||||
payloads, err := generator.loadPayloadsFromFile(file)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "could not load payloads")
|
||||
}
|
||||
loadedPayloads[name] = payloads
|
||||
break
|
||||
}
|
||||
// Multiline inline payloads
|
||||
elements := strings.Split(pt, "\n")
|
||||
//golint:gomnd // this is not a magic number
|
||||
if len(elements) >= 2 {
|
||||
|
||||
@ -1,120 +1,108 @@
|
||||
package generators
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/catalog/config"
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/catalog/disk"
|
||||
osutils "github.com/projectdiscovery/utils/os"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/catalog"
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/types"
|
||||
)
|
||||
|
||||
func TestLoadPayloads(t *testing.T) {
|
||||
// since we are changing value of global variable i.e templates directory
|
||||
// run this test as subprocess
|
||||
if os.Getenv("LOAD_PAYLOAD_NO_ACCESS") != "1" {
|
||||
cmd := exec.Command(os.Args[0], "-test.run=TestLoadPayloadsWithAccess")
|
||||
cmd.Env = append(os.Environ(), "LOAD_PAYLOAD_NO_ACCESS=1")
|
||||
err := cmd.Run()
|
||||
if e, ok := err.(*exec.ExitError); ok && !e.Success() {
|
||||
return
|
||||
type fakeCatalog struct{ catalog.Catalog }
|
||||
|
||||
func (f *fakeCatalog) OpenFile(filename string) (io.ReadCloser, error) {
|
||||
return nil, errors.New("not used")
|
||||
}
|
||||
func (f *fakeCatalog) GetTemplatePath(target string) ([]string, error) { return nil, nil }
|
||||
func (f *fakeCatalog) GetTemplatesPath(definitions []string) ([]string, map[string]error) {
|
||||
return nil, nil
|
||||
}
|
||||
func (f *fakeCatalog) ResolvePath(templateName, second string) (string, error) {
|
||||
return templateName, nil
|
||||
}
|
||||
|
||||
func newTestGenerator() *PayloadGenerator {
|
||||
opts := types.DefaultOptions()
|
||||
// inject helper loader function
|
||||
opts.LoadHelperFileFunction = func(path, templatePath string, _ catalog.Catalog) (io.ReadCloser, error) {
|
||||
switch path {
|
||||
case "fileA.txt":
|
||||
return io.NopCloser(strings.NewReader("one\n two\n\nthree\n")), nil
|
||||
default:
|
||||
return io.NopCloser(strings.NewReader("x\ny\nz\n")), nil
|
||||
}
|
||||
}
|
||||
return &PayloadGenerator{options: opts, catalog: &fakeCatalog{}}
|
||||
}
|
||||
|
||||
func TestLoadPayloads_FastPathFile(t *testing.T) {
|
||||
g := newTestGenerator()
|
||||
out, err := g.loadPayloads(map[string]interface{}{"A": "fileA.txt"}, "")
|
||||
if err != nil {
|
||||
t.Fatalf("process ran with err %v, want exit status 1", err)
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
got := out["A"]
|
||||
if len(got) != 3 || got[0] != "one" || got[1] != " two" || got[2] != "three" {
|
||||
t.Fatalf("unexpected: %#v", got)
|
||||
}
|
||||
templateDir := getTemplatesDir(t)
|
||||
config.DefaultConfig.SetTemplatesDir(templateDir)
|
||||
|
||||
generator := &PayloadGenerator{catalog: disk.NewCatalog(templateDir), options: getOptions(false)}
|
||||
fullpath := filepath.Join(templateDir, "payloads.txt")
|
||||
|
||||
// Test sandbox
|
||||
t.Run("templates-directory", func(t *testing.T) {
|
||||
// testcase when loading file from template directory and template file is in root
|
||||
// expected to succeed
|
||||
values, err := generator.loadPayloads(map[string]interface{}{
|
||||
"new": fullpath,
|
||||
}, "/test")
|
||||
require.NoError(t, err, "could not load payloads")
|
||||
require.Equal(t, map[string][]string{"new": {"test", "another"}}, values, "could not get values")
|
||||
})
|
||||
t.Run("templates-path-relative", func(t *testing.T) {
|
||||
// testcase when loading file from template directory and template file is current working directory
|
||||
// expected to fail since this is LFI
|
||||
_, err := generator.loadPayloads(map[string]interface{}{
|
||||
"new": "../../../../../../../../../etc/passwd",
|
||||
}, ".")
|
||||
require.Error(t, err, "could load payloads")
|
||||
})
|
||||
t.Run("template-directory", func(t *testing.T) {
|
||||
// testcase when loading file from template directory and template file is inside template directory
|
||||
// expected to succeed
|
||||
values, err := generator.loadPayloads(map[string]interface{}{
|
||||
"new": fullpath,
|
||||
}, filepath.Join(templateDir, "test.yaml"))
|
||||
require.NoError(t, err, "could not load payloads")
|
||||
require.Equal(t, map[string][]string{"new": {"test", "another"}}, values, "could not get values")
|
||||
})
|
||||
|
||||
t.Run("invalid", func(t *testing.T) {
|
||||
// testcase when loading file from /etc/passwd and template file is at root i.e /
|
||||
// expected to fail since this is LFI
|
||||
values, err := generator.loadPayloads(map[string]interface{}{
|
||||
"new": "/etc/passwd",
|
||||
}, "/random")
|
||||
require.Error(t, err, "could load payloads got %v", values)
|
||||
require.Equal(t, 0, len(values), "could get values")
|
||||
|
||||
// testcase when loading file from template directory and template file is at root i.e /
|
||||
// expected to succeed
|
||||
values, err = generator.loadPayloads(map[string]interface{}{
|
||||
"new": fullpath,
|
||||
}, "/random")
|
||||
require.NoError(t, err, "could load payloads %v", values)
|
||||
require.Equal(t, 1, len(values), "could get values")
|
||||
require.Equal(t, []string{"test", "another"}, values["new"], "could get values")
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoadPayloadsWithAccess(t *testing.T) {
|
||||
// since we are changing value of global variable i.e templates directory
|
||||
// run this test as subprocess
|
||||
if os.Getenv("LOAD_PAYLOAD_WITH_ACCESS") != "1" {
|
||||
cmd := exec.Command(os.Args[0], "-test.run=TestLoadPayloadsWithAccess")
|
||||
cmd.Env = append(os.Environ(), "LOAD_PAYLOAD_WITH_ACCESS=1")
|
||||
err := cmd.Run()
|
||||
if e, ok := err.(*exec.ExitError); ok && !e.Success() {
|
||||
return
|
||||
}
|
||||
func TestLoadPayloads_InlineMultiline(t *testing.T) {
|
||||
g := newTestGenerator()
|
||||
inline := "a\nb\n"
|
||||
out, err := g.loadPayloads(map[string]interface{}{"B": inline}, "")
|
||||
if err != nil {
|
||||
t.Fatalf("process ran with err %v, want exit status 1", err)
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
got := out["B"]
|
||||
if len(got) != 3 || got[0] != "a" || got[1] != "b" || got[2] != "" {
|
||||
t.Fatalf("unexpected: %#v", got)
|
||||
}
|
||||
templateDir := getTemplatesDir(t)
|
||||
config.DefaultConfig.SetTemplatesDir(templateDir)
|
||||
|
||||
generator := &PayloadGenerator{catalog: disk.NewCatalog(templateDir), options: getOptions(true)}
|
||||
|
||||
t.Run("no-sandbox-unix", func(t *testing.T) {
|
||||
if osutils.IsWindows() {
|
||||
return
|
||||
}
|
||||
_, err := generator.loadPayloads(map[string]interface{}{
|
||||
"new": "/etc/passwd",
|
||||
}, "/random")
|
||||
require.NoError(t, err, "could load payloads")
|
||||
})
|
||||
}
|
||||
|
||||
func getTemplatesDir(t *testing.T) string {
|
||||
tempdir, err := os.MkdirTemp("", "templates-*")
|
||||
require.NoError(t, err, "could not create temp dir")
|
||||
fullpath := filepath.Join(tempdir, "payloads.txt")
|
||||
err = os.WriteFile(fullpath, []byte("test\nanother"), 0777)
|
||||
require.NoError(t, err, "could not write payload")
|
||||
return tempdir
|
||||
func TestLoadPayloads_SingleLineFallsBackToFile(t *testing.T) {
|
||||
g := newTestGenerator()
|
||||
inline := "fileA.txt" // single line, should be treated as file path
|
||||
out, err := g.loadPayloads(map[string]interface{}{"C": inline}, "")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
got := out["C"]
|
||||
if len(got) != 3 {
|
||||
t.Fatalf("unexpected len: %d", len(got))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadPayloads_InterfaceSlice(t *testing.T) {
|
||||
g := newTestGenerator()
|
||||
out, err := g.loadPayloads(map[string]interface{}{"D": []interface{}{"p", "q"}}, "")
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
got := out["D"]
|
||||
if len(got) != 2 || got[0] != "p" || got[1] != "q" {
|
||||
t.Fatalf("unexpected: %#v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadPayloadsFromFile_SkipsEmpty(t *testing.T) {
|
||||
g := newTestGenerator()
|
||||
rc := io.NopCloser(strings.NewReader("a\n\n\n b \n"))
|
||||
lines, err := g.loadPayloadsFromFile(rc)
|
||||
if err != nil {
|
||||
t.Fatalf("err: %v", err)
|
||||
}
|
||||
if len(lines) != 2 || lines[0] != "a" || lines[1] != " b " {
|
||||
t.Fatalf("unexpected: %#v", lines)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidate_AllowsInlineMultiline(t *testing.T) {
|
||||
g := newTestGenerator()
|
||||
inline := "x\ny\n"
|
||||
if err := g.validate(map[string]interface{}{"E": inline}, ""); err != nil {
|
||||
t.Fatalf("validate rejected inline multiline: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -14,3 +14,32 @@ func TestMergeMapsMany(t *testing.T) {
|
||||
"c": {"5"},
|
||||
}, got, "could not get correct merged map")
|
||||
}
|
||||
|
||||
func TestMergeMapsAndExpand(t *testing.T) {
|
||||
m1 := map[string]interface{}{"a": "1"}
|
||||
m2 := map[string]interface{}{"b": "2"}
|
||||
out := MergeMaps(m1, m2)
|
||||
if out["a"].(string) != "1" || out["b"].(string) != "2" {
|
||||
t.Fatalf("unexpected merge: %#v", out)
|
||||
}
|
||||
flat := map[string]string{"x": "y"}
|
||||
exp := ExpandMapValues(flat)
|
||||
if len(exp["x"]) != 1 || exp["x"][0] != "y" {
|
||||
t.Fatalf("unexpected expand: %#v", exp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIteratorRemaining(t *testing.T) {
|
||||
g, err := New(map[string]interface{}{"k": []interface{}{"a", "b"}}, BatteringRamAttack, "", nil, "", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("new: %v", err)
|
||||
}
|
||||
it := g.NewIterator()
|
||||
if it.Total() != 2 || it.Remaining() != 2 {
|
||||
t.Fatalf("unexpected totals: %d %d", it.Total(), it.Remaining())
|
||||
}
|
||||
_, _ = it.Value()
|
||||
if it.Remaining() != 1 {
|
||||
t.Fatalf("unexpected remaining after one: %d", it.Remaining())
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
package generators
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@ -17,9 +16,8 @@ func (g *PayloadGenerator) validate(payloads map[string]interface{}, templatePat
|
||||
for name, payload := range payloads {
|
||||
switch payloadType := payload.(type) {
|
||||
case string:
|
||||
// check if it's a multiline string list
|
||||
if len(strings.Split(payloadType, "\n")) != 1 {
|
||||
return errors.New("invalid number of lines in payload")
|
||||
if strings.ContainsRune(payloadType, '\n') {
|
||||
continue
|
||||
}
|
||||
|
||||
// For historical reasons, "validate" checks to see if the payload file exist.
|
||||
|
||||
@ -92,9 +92,8 @@ type generatedRequest struct {
|
||||
|
||||
// setReqURLPattern sets the url request pattern for the generated request
|
||||
func (gr *generatedRequest) setReqURLPattern(reqURLPattern string) {
|
||||
data := strings.Split(reqURLPattern, "\n")
|
||||
if len(data) > 1 {
|
||||
reqURLPattern = strings.TrimSpace(data[0])
|
||||
if idx := strings.IndexByte(reqURLPattern, '\n'); idx >= 0 {
|
||||
reqURLPattern = strings.TrimSpace(reqURLPattern[:idx])
|
||||
// this is raw request (if it has 3 parts after strings.Fields then its valid only use 2nd part)
|
||||
parts := strings.Fields(reqURLPattern)
|
||||
if len(parts) >= 3 {
|
||||
|
||||
@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/projectdiscovery/gologger"
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/authprovider/authx"
|
||||
@ -17,6 +18,8 @@ import (
|
||||
urlutil "github.com/projectdiscovery/utils/url"
|
||||
)
|
||||
|
||||
var bufferPool = sync.Pool{New: func() any { return new(bytes.Buffer) }}
|
||||
|
||||
// Request defines a basic HTTP raw request
|
||||
type Request struct {
|
||||
FullURL string
|
||||
@ -270,13 +273,17 @@ func (r *Request) TryFillCustomHeaders(headers []string) error {
|
||||
if newLineIndex > 0 {
|
||||
newLineIndex += hostHeaderIndex + 2
|
||||
// insert custom headers
|
||||
var buf bytes.Buffer
|
||||
buf := bufferPool.Get().(*bytes.Buffer)
|
||||
buf.Reset()
|
||||
buf.Write(r.UnsafeRawBytes[:newLineIndex])
|
||||
for _, header := range headers {
|
||||
buf.WriteString(fmt.Sprintf("%s\r\n", header))
|
||||
buf.WriteString(header)
|
||||
buf.WriteString("\r\n")
|
||||
}
|
||||
buf.Write(r.UnsafeRawBytes[newLineIndex:])
|
||||
r.UnsafeRawBytes = buf.Bytes()
|
||||
r.UnsafeRawBytes = append([]byte(nil), buf.Bytes()...)
|
||||
buf.Reset()
|
||||
bufferPool.Put(buf)
|
||||
return nil
|
||||
}
|
||||
return errors.New("no new line found at the end of host header")
|
||||
@ -301,9 +308,10 @@ func (r *Request) ApplyAuthStrategy(strategy authx.AuthStrategy) {
|
||||
parsed.Params.Add(p.Key, p.Value)
|
||||
}
|
||||
case *authx.CookiesAuthStrategy:
|
||||
var buff bytes.Buffer
|
||||
buff := bufferPool.Get().(*bytes.Buffer)
|
||||
buff.Reset()
|
||||
for _, cookie := range s.Data.Cookies {
|
||||
buff.WriteString(fmt.Sprintf("%s=%s; ", cookie.Key, cookie.Value))
|
||||
fmt.Fprintf(buff, "%s=%s; ", cookie.Key, cookie.Value)
|
||||
}
|
||||
if buff.Len() > 0 {
|
||||
if val, ok := r.Headers["Cookie"]; ok {
|
||||
@ -312,6 +320,7 @@ func (r *Request) ApplyAuthStrategy(strategy authx.AuthStrategy) {
|
||||
r.Headers["Cookie"] = buff.String()
|
||||
}
|
||||
}
|
||||
bufferPool.Put(buff)
|
||||
case *authx.HeadersAuthStrategy:
|
||||
for _, header := range s.Data.Headers {
|
||||
r.Headers[header.Key] = header.Value
|
||||
|
||||
@ -7,6 +7,21 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTryFillCustomHeaders_BufferDetached(t *testing.T) {
|
||||
r := &Request{
|
||||
UnsafeRawBytes: []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\nBody"),
|
||||
}
|
||||
// first fill
|
||||
err := r.TryFillCustomHeaders([]string{"X-Test: 1"})
|
||||
require.NoError(t, err, "unexpected error on first call")
|
||||
prev := r.UnsafeRawBytes
|
||||
prevStr := string(prev) // content snapshot
|
||||
err = r.TryFillCustomHeaders([]string{"X-Another: 2"})
|
||||
require.NoError(t, err, "unexpected error on second call")
|
||||
require.Equal(t, prevStr, string(prev), "first slice mutated after second call; buffer not detached")
|
||||
require.NotEqual(t, prevStr, string(r.UnsafeRawBytes), "request bytes did not change after second call")
|
||||
}
|
||||
|
||||
func TestParseRawRequestWithPort(t *testing.T) {
|
||||
request, err := Parse(`GET /gg/phpinfo.php HTTP/1.1
|
||||
Host: {{Hostname}}:123
|
||||
|
||||
@ -240,6 +240,48 @@ func (request *Request) executeParallelHTTP(input *contextargs.Context, dynamicV
|
||||
}
|
||||
})
|
||||
|
||||
// bounded worker-pool to avoid spawning one goroutine per payload
|
||||
type task struct {
|
||||
req *generatedRequest
|
||||
updatedInput *contextargs.Context
|
||||
}
|
||||
|
||||
var workersWg sync.WaitGroup
|
||||
currentWorkers := maxWorkers
|
||||
tasks := make(chan task, maxWorkers)
|
||||
spawnWorker := func(ctx context.Context) {
|
||||
workersWg.Add(1)
|
||||
go func() {
|
||||
defer workersWg.Done()
|
||||
for t := range tasks {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
if spmHandler.FoundFirstMatch() || request.isUnresponsiveAddress(t.updatedInput) || spmHandler.Cancelled() {
|
||||
continue
|
||||
}
|
||||
spmHandler.Acquire()
|
||||
if spmHandler.FoundFirstMatch() || request.isUnresponsiveAddress(t.updatedInput) || spmHandler.Cancelled() {
|
||||
spmHandler.Release()
|
||||
continue
|
||||
}
|
||||
request.options.RateLimitTake()
|
||||
select {
|
||||
case <-spmHandler.Done():
|
||||
spmHandler.Release()
|
||||
continue
|
||||
case spmHandler.ResultChan <- request.executeRequest(t.updatedInput, t.req, make(map[string]interface{}), false, wrappedCallback, 0):
|
||||
spmHandler.Release()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
for i := 0; i < currentWorkers; i++ {
|
||||
spawnWorker(ctx)
|
||||
}
|
||||
|
||||
// iterate payloads and make requests
|
||||
generator := request.newGenerator(false)
|
||||
for {
|
||||
@ -259,6 +301,13 @@ func (request *Request) executeParallelHTTP(input *contextargs.Context, dynamicV
|
||||
if err := spmHandler.Resize(input.Context(), request.options.Options.PayloadConcurrency); err != nil {
|
||||
return err
|
||||
}
|
||||
// if payload concurrency increased, add more workers
|
||||
if spmHandler.Size() > currentWorkers {
|
||||
for i := 0; i < spmHandler.Size()-currentWorkers; i++ {
|
||||
spawnWorker(ctx)
|
||||
}
|
||||
currentWorkers = spmHandler.Size()
|
||||
}
|
||||
}
|
||||
|
||||
// break if stop at first match is found or host is unresponsive
|
||||
@ -284,29 +333,21 @@ func (request *Request) executeParallelHTTP(input *contextargs.Context, dynamicV
|
||||
spmHandler.Cancel()
|
||||
return nil
|
||||
}
|
||||
spmHandler.Acquire()
|
||||
go func(httpRequest *generatedRequest) {
|
||||
defer spmHandler.Release()
|
||||
if spmHandler.FoundFirstMatch() || request.isUnresponsiveAddress(updatedInput) || spmHandler.Cancelled() {
|
||||
return
|
||||
}
|
||||
// putting ratelimiter here prevents any unnecessary waiting if any
|
||||
request.options.RateLimitTake()
|
||||
|
||||
// after ratelimit take, check if we need to stop
|
||||
if spmHandler.FoundFirstMatch() || request.isUnresponsiveAddress(updatedInput) || spmHandler.Cancelled() {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-spmHandler.Done():
|
||||
return
|
||||
case spmHandler.ResultChan <- request.executeRequest(input, httpRequest, make(map[string]interface{}), false, wrappedCallback, 0):
|
||||
return
|
||||
close(tasks)
|
||||
workersWg.Wait()
|
||||
spmHandler.Wait()
|
||||
if spmHandler.FoundFirstMatch() {
|
||||
return nil
|
||||
}
|
||||
return multierr.Combine(spmHandler.CombinedResults()...)
|
||||
case tasks <- task{req: generatedHttpRequest, updatedInput: updatedInput}:
|
||||
}
|
||||
}(generatedHttpRequest)
|
||||
request.options.Progress.IncrementRequests()
|
||||
}
|
||||
close(tasks)
|
||||
workersWg.Wait()
|
||||
spmHandler.Wait()
|
||||
if spmHandler.FoundFirstMatch() {
|
||||
// ignore any context cancellation and in-transit execution errors
|
||||
|
||||
@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -257,3 +258,116 @@ func TestReqURLPattern(t *testing.T) {
|
||||
require.NotEmpty(t, finalEvent.Results[0].ReqURLPattern, "could not get req url pattern")
|
||||
require.Equal(t, `/{{rand_char("abc")}}/{{interactsh-url}}/123?query={{rand_int(1, 10)}}&data={{randstr}}`, finalEvent.Results[0].ReqURLPattern)
|
||||
}
|
||||
|
||||
// fakeHostErrorsCache implements hosterrorscache.CacheInterface minimally for tests
|
||||
type fakeHostErrorsCache struct{}
|
||||
|
||||
func (f *fakeHostErrorsCache) SetVerbose(bool) {}
|
||||
func (f *fakeHostErrorsCache) Close() {}
|
||||
func (f *fakeHostErrorsCache) Remove(*contextargs.Context) {}
|
||||
func (f *fakeHostErrorsCache) MarkFailed(string, *contextargs.Context, error) {}
|
||||
func (f *fakeHostErrorsCache) MarkFailedOrRemove(string, *contextargs.Context, error) {
|
||||
}
|
||||
|
||||
// Check always returns true to simulate an already unresponsive host
|
||||
func (f *fakeHostErrorsCache) Check(string, *contextargs.Context) bool { return true }
|
||||
|
||||
func TestExecuteParallelHTTP_StopAtFirstMatch(t *testing.T) {
|
||||
options := testutils.DefaultOptions
|
||||
testutils.Init(options)
|
||||
|
||||
// server that always matches
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = fmt.Fprintf(w, "match")
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
templateID := "parallel-stop-first"
|
||||
req := &Request{
|
||||
ID: templateID,
|
||||
Method: HTTPMethodTypeHolder{MethodType: HTTPGet},
|
||||
Path: []string{"{{BaseURL}}/p?x={{v}}"},
|
||||
Threads: 2,
|
||||
Payloads: map[string]interface{}{
|
||||
"v": []string{"1", "2"},
|
||||
},
|
||||
Operators: operators.Operators{
|
||||
Matchers: []*matchers.Matcher{{
|
||||
Part: "body",
|
||||
Type: matchers.MatcherTypeHolder{MatcherType: matchers.WordsMatcher},
|
||||
Words: []string{"match"},
|
||||
}},
|
||||
},
|
||||
StopAtFirstMatch: true,
|
||||
}
|
||||
|
||||
executerOpts := testutils.NewMockExecuterOptions(options, &testutils.TemplateInfo{
|
||||
ID: templateID,
|
||||
Info: model.Info{SeverityHolder: severity.Holder{Severity: severity.Low}, Name: "test"},
|
||||
})
|
||||
err := req.Compile(executerOpts)
|
||||
require.NoError(t, err)
|
||||
|
||||
var matches int32
|
||||
metadata := make(output.InternalEvent)
|
||||
previous := make(output.InternalEvent)
|
||||
ctxArgs := contextargs.NewWithInput(context.Background(), ts.URL)
|
||||
err = req.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) {
|
||||
if event.OperatorsResult != nil && event.OperatorsResult.Matched {
|
||||
atomic.AddInt32(&matches, 1)
|
||||
}
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&matches), "expected only first match to be processed")
|
||||
}
|
||||
|
||||
func TestExecuteParallelHTTP_SkipOnUnresponsiveFromCache(t *testing.T) {
|
||||
options := testutils.DefaultOptions
|
||||
testutils.Init(options)
|
||||
|
||||
// server that would match if reached
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = fmt.Fprintf(w, "match")
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
templateID := "parallel-skip-unresponsive"
|
||||
req := &Request{
|
||||
ID: templateID,
|
||||
Method: HTTPMethodTypeHolder{MethodType: HTTPGet},
|
||||
Path: []string{"{{BaseURL}}/p?x={{v}}"},
|
||||
Threads: 2,
|
||||
Payloads: map[string]interface{}{
|
||||
"v": []string{"1", "2"},
|
||||
},
|
||||
Operators: operators.Operators{
|
||||
Matchers: []*matchers.Matcher{{
|
||||
Part: "body",
|
||||
Type: matchers.MatcherTypeHolder{MatcherType: matchers.WordsMatcher},
|
||||
Words: []string{"match"},
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
executerOpts := testutils.NewMockExecuterOptions(options, &testutils.TemplateInfo{
|
||||
ID: templateID,
|
||||
Info: model.Info{SeverityHolder: severity.Holder{Severity: severity.Low}, Name: "test"},
|
||||
})
|
||||
// inject fake host errors cache that forces skip
|
||||
executerOpts.HostErrorsCache = &fakeHostErrorsCache{}
|
||||
|
||||
err := req.Compile(executerOpts)
|
||||
require.NoError(t, err)
|
||||
|
||||
var matches int32
|
||||
metadata := make(output.InternalEvent)
|
||||
previous := make(output.InternalEvent)
|
||||
ctxArgs := contextargs.NewWithInput(context.Background(), ts.URL)
|
||||
err = req.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) {
|
||||
if event.OperatorsResult != nil && event.OperatorsResult.Matched {
|
||||
atomic.AddInt32(&matches, 1)
|
||||
}
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&matches), "expected no matches when host is marked unresponsive")
|
||||
}
|
||||
|
||||
@ -811,8 +811,11 @@ func beautifyJavascript(code string) string {
|
||||
}
|
||||
|
||||
func prettyPrint(templateId string, buff string) {
|
||||
if buff == "" {
|
||||
return
|
||||
}
|
||||
lines := strings.Split(buff, "\n")
|
||||
final := []string{}
|
||||
final := make([]string, 0, len(lines))
|
||||
for _, v := range lines {
|
||||
if v != "" {
|
||||
final = append(final, "\t"+v)
|
||||
|
||||
@ -192,11 +192,13 @@ func New(options *Options, db string, doNotDedupe bool) (Client, error) {
|
||||
}
|
||||
}
|
||||
|
||||
if db != "" || len(client.trackers) > 0 || len(client.exporters) > 0 {
|
||||
storage, err := dedupe.New(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
client.dedupe = storage
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
|
||||
@ -12,7 +12,9 @@ type Cache struct {
|
||||
|
||||
// New returns a new templates cache
|
||||
func NewCache() *Cache {
|
||||
return &Cache{items: mapsutil.NewSyncLockMap[string, parsedTemplate]()}
|
||||
return &Cache{
|
||||
items: mapsutil.NewSyncLockMap[string, parsedTemplate](),
|
||||
}
|
||||
}
|
||||
|
||||
type parsedTemplate struct {
|
||||
@ -33,7 +35,31 @@ func (t *Cache) Has(template string) (*Template, []byte, error) {
|
||||
|
||||
// Store stores a template with data and error
|
||||
func (t *Cache) Store(id string, tpl *Template, raw []byte, err error) {
|
||||
_ = t.items.Set(id, parsedTemplate{template: tpl, raw: conversion.String(raw), err: err})
|
||||
entry := parsedTemplate{
|
||||
template: tpl,
|
||||
err: err,
|
||||
raw: conversion.String(raw),
|
||||
}
|
||||
_ = t.items.Set(id, entry)
|
||||
}
|
||||
|
||||
// StoreWithoutRaw stores a template without raw data for memory efficiency
|
||||
func (t *Cache) StoreWithoutRaw(id string, tpl *Template, err error) {
|
||||
entry := parsedTemplate{
|
||||
template: tpl,
|
||||
err: err,
|
||||
raw: "",
|
||||
}
|
||||
_ = t.items.Set(id, entry)
|
||||
}
|
||||
|
||||
// Get returns only the template without raw bytes
|
||||
func (t *Cache) Get(id string) (*Template, error) {
|
||||
value, ok := t.items.Get(id)
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
return value.template, value.err
|
||||
}
|
||||
|
||||
// Purge the cache
|
||||
|
||||
@ -404,7 +404,7 @@ func (template *Template) ImportFileRefs(options *protocols.ExecutorOptions) err
|
||||
// for code protocol requests
|
||||
for _, request := range template.RequestsCode {
|
||||
// simple test to check if source is a file or a snippet
|
||||
if len(strings.Split(request.Source, "\n")) == 1 && fileutil.FileExists(request.Source) {
|
||||
if !strings.ContainsRune(request.Source, '\n') && fileutil.FileExists(request.Source) {
|
||||
if val, ok := loadFile(request.Source); ok {
|
||||
template.ImportedFiles = append(template.ImportedFiles, request.Source)
|
||||
request.Source = val
|
||||
@ -415,7 +415,7 @@ func (template *Template) ImportFileRefs(options *protocols.ExecutorOptions) err
|
||||
// for javascript protocol code references
|
||||
for _, request := range template.RequestsJavascript {
|
||||
// simple test to check if source is a file or a snippet
|
||||
if len(strings.Split(request.Code, "\n")) == 1 && fileutil.FileExists(request.Code) {
|
||||
if !strings.ContainsRune(request.Code, '\n') && fileutil.FileExists(request.Code) {
|
||||
if val, ok := loadFile(request.Code); ok {
|
||||
template.ImportedFiles = append(template.ImportedFiles, request.Code)
|
||||
request.Code = val
|
||||
@ -442,7 +442,7 @@ func (template *Template) ImportFileRefs(options *protocols.ExecutorOptions) err
|
||||
if req.Type() == types.CodeProtocol {
|
||||
request := req.(*code.Request)
|
||||
// simple test to check if source is a file or a snippet
|
||||
if len(strings.Split(request.Source, "\n")) == 1 && fileutil.FileExists(request.Source) {
|
||||
if !strings.ContainsRune(request.Source, '\n') && fileutil.FileExists(request.Source) {
|
||||
if val, ok := loadFile(request.Source); ok {
|
||||
template.ImportedFiles = append(template.ImportedFiles, request.Source)
|
||||
request.Source = val
|
||||
@ -456,7 +456,7 @@ func (template *Template) ImportFileRefs(options *protocols.ExecutorOptions) err
|
||||
if req.Type() == types.JavascriptProtocol {
|
||||
request := req.(*javascript.Request)
|
||||
// simple test to check if source is a file or a snippet
|
||||
if len(strings.Split(request.Code, "\n")) == 1 && fileutil.FileExists(request.Code) {
|
||||
if !strings.ContainsRune(request.Code, '\n') && fileutil.FileExists(request.Code) {
|
||||
if val, ok := loadFile(request.Code); ok {
|
||||
template.ImportedFiles = append(template.ImportedFiles, request.Code)
|
||||
request.Code = val
|
||||
|
||||
@ -9,6 +9,32 @@ import (
|
||||
"gopkg.in/yaml.v2"
|
||||
)
|
||||
|
||||
func TestCachePoolZeroing(t *testing.T) {
|
||||
c := NewCache()
|
||||
|
||||
tpl := &Template{ID: "x"}
|
||||
raw := []byte("SOME BIG RAW")
|
||||
|
||||
c.Store("id1", tpl, raw, nil)
|
||||
gotTpl, gotErr := c.Get("id1")
|
||||
if gotErr != nil {
|
||||
t.Fatalf("unexpected err: %v", gotErr)
|
||||
}
|
||||
if gotTpl == nil || gotTpl.ID != "x" {
|
||||
t.Fatalf("unexpected tpl: %#v", gotTpl)
|
||||
}
|
||||
|
||||
// StoreWithoutRaw should not retain raw
|
||||
c.StoreWithoutRaw("id2", tpl, nil)
|
||||
gotTpl2, gotErr2 := c.Get("id2")
|
||||
if gotErr2 != nil {
|
||||
t.Fatalf("unexpected err: %v", gotErr2)
|
||||
}
|
||||
if gotTpl2 == nil || gotTpl2.ID != "x" {
|
||||
t.Fatalf("unexpected tpl2: %#v", gotTpl2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTemplateStruct(t *testing.T) {
|
||||
templatePath := "./tests/match-1.yaml"
|
||||
bin, err := os.ReadFile(templatePath)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user