implicit thread count when not specified in payloads + threads support in dns,network (#4715)

* default threads + add threads support in dns payloads

* add threads support in network protocol

* add optional callback to override threadSetter

* fix broken fuzz integration tests
This commit is contained in:
Tarun Koyalwar 2024-02-02 02:05:30 +05:30 committed by GitHub
parent e4298a5ae1
commit ead58f4ab9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 94 additions and 13 deletions

View File

@ -75,6 +75,12 @@ type Request struct {
// of payloads is provided, or optionally a single file can also
// be provided as payload which will be read on run-time.
Payloads map[string]interface{} `yaml:"payloads,omitempty" json:"payloads,omitempty" jsonschema:"title=payloads for the network request,description=Payloads contains any payloads for the current request"`
// description: |
// Threads to use when sending iterating over payloads
// examples:
// - name: Send requests using 10 concurrent threads
// value: 10
Threads int `yaml:"threads,omitempty" json:"threads,omitempty" jsonschema:"title=threads for sending requests,description=Threads specifies number of threads to use sending requests. This enables Connection Pooling"`
generator *generators.PayloadGenerator
CompiledOperators *operators.Operators `yaml:"-"`
@ -176,6 +182,8 @@ func (request *Request) Compile(options *protocols.ExecutorOptions) error {
if err != nil {
return errors.Wrap(err, "could not parse payloads")
}
// default to 20 threads for payload requests
request.Threads = options.GetThreadsForNPayloadRequests(request.Requests(), request.Threads)
}
return nil
}

View File

@ -5,9 +5,12 @@ import (
"fmt"
"net/url"
"strings"
"sync"
"github.com/miekg/dns"
"github.com/pkg/errors"
"github.com/remeh/sizedwaitgroup"
"go.uber.org/multierr"
"golang.org/x/exp/maps"
"github.com/projectdiscovery/gologger"
@ -61,6 +64,9 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, metadata,
if request.generator != nil {
iterator := request.generator.NewIterator()
swg := sizedwaitgroup.New(request.Threads)
var multiErr error
m := &sync.Mutex{}
for {
value, ok := iterator.Value()
@ -68,9 +74,19 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, metadata,
break
}
value = generators.MergeMaps(vars, value)
if err := request.execute(input, domain, metadata, previous, value, callback); err != nil {
return err
swg.Add()
go func(newVars map[string]interface{}) {
defer swg.Done()
if err := request.execute(input, domain, metadata, previous, newVars, callback); err != nil {
m.Lock()
multiErr = multierr.Append(multiErr, err)
m.Unlock()
}
}(value)
}
swg.Wait()
if multiErr != nil {
return multiErr
}
} else {
value := maps.Clone(vars)

View File

@ -387,6 +387,11 @@ func (request *Request) Compile(options *protocols.ExecutorOptions) error {
}
}
}
if len(request.Payloads) > 0 {
// if we have payloads, adjust threads if none specified
request.Threads = options.GetThreadsForNPayloadRequests(request.Requests(), request.Threads)
}
return nil
}

View File

@ -342,16 +342,16 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, dynamicVa
return request.executeRaceRequest(input, dynamicValues, callback)
}
// verify if parallel elaboration was requested
if request.Threads > 0 {
return request.executeParallelHTTP(input, dynamicValues, callback)
}
// verify if fuzz elaboration was requested
if len(request.Fuzzing) > 0 {
return request.executeFuzzingRule(input, dynamicValues, callback)
}
// verify if parallel elaboration was requested
if request.Threads > 0 {
return request.executeParallelHTTP(input, dynamicValues, callback)
}
generator := request.newGenerator(false)
var gotDynamicValues map[string][]string

View File

@ -107,6 +107,8 @@ func (request *Request) Compile(options *protocols.ExecutorOptions) error {
if err != nil {
return errors.Wrap(err, "could not parse payloads")
}
// default to 20 threads for payload requests
request.Threads = options.GetThreadsForNPayloadRequests(request.Requests(), request.Threads)
}
if len(request.Matchers) > 0 || len(request.Extractors) > 0 {

View File

@ -45,6 +45,15 @@ type Request struct {
// of payloads is provided, or optionally a single file can also
// be provided as payload which will be read on run-time.
Payloads map[string]interface{} `yaml:"payloads,omitempty" json:"payloads,omitempty" jsonschema:"title=payloads for the network request,description=Payloads contains any payloads for the current request"`
// description: |
// Threads specifies number of threads to use sending requests. This enables Connection Pooling.
//
// Connection: Close attribute must not be used in request while using threads flag, otherwise
// pooling will fail and engine will continue to close connections after requests.
// examples:
// - name: Send requests using 10 concurrent threads
// value: 10
Threads int `yaml:"threads,omitempty" json:"threads,omitempty" jsonschema:"title=threads for sending requests,description=Threads specifies number of threads to use sending requests. This enables Connection Pooling"`
// description: |
// Inputs contains inputs for the network socket
@ -219,6 +228,8 @@ func (request *Request) Compile(options *protocols.ExecutorOptions) error {
if err != nil {
return errors.Wrap(err, "could not parse payloads")
}
// if we have payloads, adjust threads if none specified
request.Threads = options.GetThreadsForNPayloadRequests(request.Requests(), request.Threads)
}
// Create a client for the class

View File

@ -8,9 +8,11 @@ import (
"net/url"
"os"
"strings"
"sync"
"time"
"github.com/pkg/errors"
"github.com/remeh/sizedwaitgroup"
"go.uber.org/multierr"
"golang.org/x/exp/maps"
@ -174,6 +176,9 @@ func (request *Request) executeAddress(variables map[string]interface{}, actualA
if request.generator != nil {
iterator := request.generator.NewIterator()
var multiErr error
m := &sync.Mutex{}
swg := sizedwaitgroup.New(request.Threads)
for {
value, ok := iterator.Value()
@ -181,9 +186,19 @@ func (request *Request) executeAddress(variables map[string]interface{}, actualA
break
}
value = generators.MergeMaps(value, payloads)
if err := request.executeRequestWithPayloads(variables, actualAddress, address, input, shouldUseTLS, value, previous, callback); err != nil {
return err
swg.Add()
go func(vars map[string]interface{}) {
defer swg.Done()
if err := request.executeRequestWithPayloads(variables, actualAddress, address, input, shouldUseTLS, vars, previous, callback); err != nil {
m.Lock()
multiErr = multierr.Append(multiErr, err)
m.Unlock()
}
}(value)
}
swg.Wait()
if multiErr != nil {
return multiErr
}
} else {
value := maps.Clone(payloads)

View File

@ -32,7 +32,12 @@ import (
"github.com/projectdiscovery/nuclei/v3/pkg/types"
)
var MaxTemplateFileSizeForEncoding = 1024 * 1024
// Optional Callback to update Thread count in payloads across all requests
type PayloadThreadSetterCallback func(opts *ExecutorOptions, totalRequests, currentThreads int) int
var (
MaxTemplateFileSizeForEncoding = 1024 * 1024
)
// Executer is an interface implemented any protocol based request executer.
type Executer interface {
@ -107,6 +112,25 @@ type ExecutorOptions struct {
// JsCompiler is abstracted javascript compiler which adds node modules and provides execution
// environment for javascript templates
JsCompiler *compiler.Compiler
// Optional Callback function to update Thread count in payloads across all protocols
// based on given logic. by default nuclei reverts to using value of `-c` when threads count
// is not specified or is 0 in template
OverrideThreadsCount PayloadThreadSetterCallback
}
// GetThreadsForPayloadRequests returns the number of threads to use as default for
// given max-request of payloads
func (e *ExecutorOptions) GetThreadsForNPayloadRequests(totalRequests int, currentThreads int) int {
if e.OverrideThreadsCount != nil {
return e.OverrideThreadsCount(e, totalRequests, currentThreads)
}
if currentThreads != 0 {
return currentThreads
}
if totalRequests <= 0 {
return e.Options.TemplateThreads
}
return totalRequests
}
// CreateTemplateCtxStore creates template context store (which contains templateCtx for every scan)