feat: fixed max-host-error blocking + progress mismatch + misc (#6193)

* feat: fixed max-host-error blocking wrong port for template with error

* feat: log total results with time taken at end of execution

* bugfix: skip non-executed requests with progress in flow protocol

* feat: fixed request calculation in http protocol for progress

* misc adjustments

---------

Co-authored-by: Ice3man <nizamulrana@gmail.com>
This commit is contained in:
Sandeep Singh 2025-05-07 17:22:15 +05:30 committed by GitHub
parent b9d0f2585f
commit 4801cc65ef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 168 additions and 41 deletions

View File

@ -702,6 +702,7 @@ func (r *Runner) RunEnumeration() error {
}() }()
} }
now := time.Now()
enumeration := false enumeration := false
var results *atomic.Bool var results *atomic.Bool
results, err = r.runStandardEnumeration(executorOpts, store, executorEngine) results, err = r.runStandardEnumeration(executorOpts, store, executorEngine)
@ -725,11 +726,17 @@ func (r *Runner) RunEnumeration() error {
} }
r.fuzzFrequencyCache.Close() r.fuzzFrequencyCache.Close()
r.progress.Stop()
timeTaken := time.Since(now)
// todo: error propagation without canonical straight error check is required by cloud? // todo: error propagation without canonical straight error check is required by cloud?
// use safe dereferencing to avoid potential panics in case of previous unchecked errors // use safe dereferencing to avoid potential panics in case of previous unchecked errors
if v := ptrutil.Safe(results); !v.Load() { if v := ptrutil.Safe(results); !v.Load() {
gologger.Info().Msgf("No results found. Better luck next time!") gologger.Info().Msgf("Scan completed in %s. No results found.", shortDur(timeTaken))
} else {
matchCount := r.output.ResultCount()
gologger.Info().Msgf("Scan completed in %s. %d matches found.", shortDur(timeTaken), matchCount)
} }
// check if a passive scan was requested but no target was provided // check if a passive scan was requested but no target was provided
if r.options.OfflineHTTP && len(r.options.Targets) == 0 && r.options.TargetsFilePath == "" { if r.options.OfflineHTTP && len(r.options.Targets) == 0 && r.options.TargetsFilePath == "" {
return errors.Wrap(err, "missing required input (http response) to run passive templates") return errors.Wrap(err, "missing required input (http response) to run passive templates")
@ -738,6 +745,24 @@ func (r *Runner) RunEnumeration() error {
return err return err
} }
func shortDur(d time.Duration) string {
if d < time.Minute {
return d.String()
}
// Truncate to the nearest minute
d = d.Truncate(time.Minute)
s := d.String()
if strings.HasSuffix(s, "m0s") {
s = s[:len(s)-2]
}
if strings.HasSuffix(s, "h0m") {
s = s[:len(s)-2]
}
return s
}
func (r *Runner) isInputNonHTTP() bool { func (r *Runner) isInputNonHTTP() bool {
var nonURLInput bool var nonURLInput bool
r.inputProvider.Iterate(func(value *contextargs.MetaInput) bool { r.inputProvider.Iterate(func(value *contextargs.MetaInput) bool {

View File

@ -110,6 +110,8 @@ func (e *Engine) executeTemplateSpray(ctx context.Context, templatesList []*temp
defer wp.Wait() defer wp.Wait()
for _, template := range templatesList { for _, template := range templatesList {
template := template
select { select {
case <-ctx.Done(): case <-ctx.Done():
return results return results

View File

@ -65,3 +65,13 @@ func (mw *MultiWriter) RequestStatsLog(statusCode, response string) {
writer.RequestStatsLog(statusCode, response) writer.RequestStatsLog(statusCode, response)
} }
} }
func (mw *MultiWriter) ResultCount() int {
count := 0
for _, writer := range mw.writers {
if count := writer.ResultCount(); count > 0 {
return count
}
}
return count
}

View File

@ -54,6 +54,8 @@ type Writer interface {
RequestStatsLog(statusCode, response string) RequestStatsLog(statusCode, response string)
// WriteStoreDebugData writes the request/response debug data to file // WriteStoreDebugData writes the request/response debug data to file
WriteStoreDebugData(host, templateID, eventType string, data string) WriteStoreDebugData(host, templateID, eventType string, data string)
// ResultCount returns the total number of results written
ResultCount() int
} }
// StandardWriter is a writer writing output to file and screen for results. // StandardWriter is a writer writing output to file and screen for results.
@ -79,6 +81,8 @@ type StandardWriter struct {
// JSONLogRequestHook is a hook that can be used to log request/response // JSONLogRequestHook is a hook that can be used to log request/response
// when using custom server code with output // when using custom server code with output
JSONLogRequestHook func(*JSONLogRequest) JSONLogRequestHook func(*JSONLogRequest)
resultCount atomic.Int32
} }
var _ Writer = &StandardWriter{} var _ Writer = &StandardWriter{}
@ -287,6 +291,10 @@ func NewStandardWriter(options *types.Options) (*StandardWriter, error) {
return writer, nil return writer, nil
} }
func (w *StandardWriter) ResultCount() int {
return int(w.resultCount.Load())
}
// Write writes the event to file and/or screen. // Write writes the event to file and/or screen.
func (w *StandardWriter) Write(event *ResultEvent) error { func (w *StandardWriter) Write(event *ResultEvent) error {
// Enrich the result event with extra metadata on the template-path and url. // Enrich the result event with extra metadata on the template-path and url.
@ -336,6 +344,7 @@ func (w *StandardWriter) Write(event *ResultEvent) error {
_, _ = w.outputFile.Write([]byte("\n")) _, _ = w.outputFile.Write([]byte("\n"))
} }
} }
w.resultCount.Add(1)
return nil return nil
} }

View File

@ -49,3 +49,6 @@ func (tw *StatsOutputWriter) RequestStatsLog(statusCode, response string) {
tw.Tracker.TrackStatusCode(statusCode) tw.Tracker.TrackStatusCode(statusCode)
tw.Tracker.TrackWAFDetected(response) tw.Tracker.TrackWAFDetected(response)
} }
func (tw *StatsOutputWriter) ResultCount() int {
return 0
}

View File

@ -120,9 +120,7 @@ func (p *StatsTicker) IncrementRequests() {
// SetRequests sets the counter by incrementing it with a delta // SetRequests sets the counter by incrementing it with a delta
func (p *StatsTicker) SetRequests(count uint64) { func (p *StatsTicker) SetRequests(count uint64) {
value, _ := p.stats.GetCounter("requests") p.stats.IncrementCounter("requests", int(count))
delta := count - value
p.stats.IncrementCounter("requests", int(delta))
} }
// IncrementMatched increments the matched counter by 1. // IncrementMatched increments the matched counter by 1.

View File

@ -89,6 +89,9 @@ func (c *Cache) NormalizeCacheValue(value string) string {
u, err := url.ParseRequestURI(value) u, err := url.ParseRequestURI(value)
if err != nil || u.Host == "" { if err != nil || u.Host == "" {
if strings.Contains(value, ":") {
return normalizedValue
}
u, err2 := url.ParseRequestURI("https://" + value) u, err2 := url.ParseRequestURI("https://" + value)
if err2 != nil { if err2 != nil {
return normalizedValue return normalizedValue
@ -236,14 +239,19 @@ func (c *Cache) GetKeyFromContext(ctx *contextargs.Context, err error) string {
// should be reflected in contextargs but it is not yet reflected in some cases // should be reflected in contextargs but it is not yet reflected in some cases
// and needs refactor of ScanContext + ContextArgs to achieve that // and needs refactor of ScanContext + ContextArgs to achieve that
// i.e why we use real address from error if present // i.e why we use real address from error if present
address := ctx.MetaInput.Address() var address string
// get address override from error
// 1. the address carried inside the error (if the transport sets it)
if err != nil { if err != nil {
tmp := errkit.GetAttrValue(err, "address") if v := errkit.GetAttrValue(err, "address"); v.Any() != nil {
if tmp.Any() != nil { address = v.String()
address = tmp.String()
} }
} }
if address == "" {
address = ctx.MetaInput.Address()
}
finalValue := c.NormalizeCacheValue(address) finalValue := c.NormalizeCacheValue(address)
return finalValue return finalValue
} }

View File

@ -123,14 +123,6 @@ func (g *generatedRequest) URL() string {
return "" return ""
} }
// Total returns the total number of requests for the generator
func (r *requestGenerator) Total() int {
if r.payloadIterator != nil {
return len(r.request.Raw) * r.payloadIterator.Remaining()
}
return len(r.request.Path)
}
// Make creates a http request for the provided input. // Make creates a http request for the provided input.
// It returns ErrNoMoreRequests as error when all the requests have been exhausted. // It returns ErrNoMoreRequests as error when all the requests have been exhausted.
func (r *requestGenerator) Make(ctx context.Context, input *contextargs.Context, reqData string, payloads, dynamicValues map[string]interface{}) (gr *generatedRequest, err error) { func (r *requestGenerator) Make(ctx context.Context, input *contextargs.Context, reqData string, payloads, dynamicValues map[string]interface{}) (gr *generatedRequest, err error) {

View File

@ -501,7 +501,6 @@ func (request *Request) Compile(options *protocols.ExecutorOptions) error {
request.Threads = options.GetThreadsForNPayloadRequests(request.Requests(), request.Threads) request.Threads = options.GetThreadsForNPayloadRequests(request.Requests(), request.Threads)
} }
} }
return nil return nil
} }
@ -517,24 +516,8 @@ func (request *Request) RebuildGenerator() error {
// Requests returns the total number of requests the YAML rule will perform // Requests returns the total number of requests the YAML rule will perform
func (request *Request) Requests() int { func (request *Request) Requests() int {
if request.generator != nil { generator := request.newGenerator(false)
payloadRequests := request.generator.NewIterator().Total() return generator.Total()
if len(request.Raw) > 0 {
payloadRequests = payloadRequests * len(request.Raw)
}
if len(request.Path) > 0 {
payloadRequests = payloadRequests * len(request.Path)
}
return payloadRequests
}
if len(request.Raw) > 0 {
requests := len(request.Raw)
if requests == 1 && request.RaceNumberRequests != 0 {
requests *= request.RaceNumberRequests
}
return requests
}
return len(request.Path)
} }
const ( const (

View File

@ -41,7 +41,6 @@ import (
"github.com/projectdiscovery/rawhttp" "github.com/projectdiscovery/rawhttp"
convUtil "github.com/projectdiscovery/utils/conversion" convUtil "github.com/projectdiscovery/utils/conversion"
"github.com/projectdiscovery/utils/errkit" "github.com/projectdiscovery/utils/errkit"
errorutil "github.com/projectdiscovery/utils/errors"
httpUtils "github.com/projectdiscovery/utils/http" httpUtils "github.com/projectdiscovery/utils/http"
"github.com/projectdiscovery/utils/reader" "github.com/projectdiscovery/utils/reader"
sliceutil "github.com/projectdiscovery/utils/slice" sliceutil "github.com/projectdiscovery/utils/slice"
@ -484,7 +483,6 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, dynamicVa
if err == types.ErrNoMoreRequests { if err == types.ErrNoMoreRequests {
return true, nil return true, nil
} }
request.options.Progress.IncrementFailedRequestsBy(int64(generator.Total()))
return true, err return true, err
} }
// ideally if http template used a custom port or hostname // ideally if http template used a custom port or hostname
@ -541,14 +539,19 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, dynamicVa
if errors.Is(execReqErr, ErrMissingVars) { if errors.Is(execReqErr, ErrMissingVars) {
return true, nil return true, nil
} }
if execReqErr != nil { if execReqErr != nil {
request.markHostError(updatedInput, execReqErr)
// if applicable mark the host as unresponsive // if applicable mark the host as unresponsive
requestErr = errorutil.NewWithErr(execReqErr).Msgf("got err while executing %v", generatedHttpRequest.URL()) reqKitErr := errkit.FromError(execReqErr)
reqKitErr.Msgf("got err while executing %v", generatedHttpRequest.URL())
requestErr = reqKitErr
request.options.Progress.IncrementFailedRequestsBy(1) request.options.Progress.IncrementFailedRequestsBy(1)
} else { } else {
request.options.Progress.IncrementRequests() request.options.Progress.IncrementRequests()
} }
request.markHostError(updatedInput, execReqErr)
// If this was a match, and we want to stop at first match, skip all further requests. // If this was a match, and we want to stop at first match, skip all further requests.
shouldStopAtFirstMatch := generatedHttpRequest.original.options.Options.StopAtFirstMatch || generatedHttpRequest.original.options.StopAtFirstMatch || request.StopAtFirstMatch shouldStopAtFirstMatch := generatedHttpRequest.original.options.Options.StopAtFirstMatch || generatedHttpRequest.original.options.StopAtFirstMatch || request.StopAtFirstMatch
@ -585,6 +588,7 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, dynamicVa
requestErr = gotErr requestErr = gotErr
} }
if skip || gotErr != nil { if skip || gotErr != nil {
request.options.Progress.SetRequests(uint64(generator.Remaining() + 1))
break break
} }
} }
@ -1212,7 +1216,7 @@ func (request *Request) newContext(input *contextargs.Context) context.Context {
// markHostError checks if the error is a unreponsive host error and marks it // markHostError checks if the error is a unreponsive host error and marks it
func (request *Request) markHostError(input *contextargs.Context, err error) { func (request *Request) markHostError(input *contextargs.Context, err error) {
if request.options.HostErrorsCache != nil { if request.options.HostErrorsCache != nil && err != nil {
request.options.HostErrorsCache.MarkFailedOrRemove(request.options.ProtocolType.String(), input, err) request.options.HostErrorsCache.MarkFailedOrRemove(request.options.ProtocolType.String(), input, err)
} }
} }

View File

@ -135,3 +135,67 @@ func (r *requestGenerator) hasMarker(request string, mark flowMark) bool {
fo, hasOverrides := parseFlowAnnotations(request) fo, hasOverrides := parseFlowAnnotations(request)
return hasOverrides && fo == mark return hasOverrides && fo == mark
} }
// Remaining returns the number of requests that are still left to be
// generated (and therefore to be sent) by this generator.
func (r *requestGenerator) Remaining() int {
var sequence []string
switch {
case len(r.request.Path) > 0:
sequence = r.request.Path
case len(r.request.Raw) > 0:
sequence = r.request.Raw
default:
return 0
}
remainingInCurrentPass := 0
for i := r.currentIndex; i < len(sequence); i++ {
if !r.hasMarker(sequence[i], Once) {
remainingInCurrentPass++
}
}
if r.payloadIterator == nil {
return remainingInCurrentPass
}
numRemainingPayloadSets := r.payloadIterator.Remaining()
totalValidInSequence := 0
for _, req := range sequence {
if !r.hasMarker(req, Once) {
totalValidInSequence++
}
}
// Total remaining = remaining in current pass + (remaining payload sets * requests per full pass)
return remainingInCurrentPass + numRemainingPayloadSets*totalValidInSequence
}
func (r *requestGenerator) Total() int {
var sequence []string
switch {
case len(r.request.Path) > 0:
sequence = r.request.Path
case len(r.request.Raw) > 0:
sequence = r.request.Raw
default:
return 0
}
applicableRequests := 0
additionalRequests := 0
for _, request := range sequence {
if !r.hasMarker(request, Once) {
applicableRequests++
} else {
additionalRequests++
}
}
if r.payloadIterator == nil {
return applicableRequests + additionalRequests
}
return (applicableRequests * r.payloadIterator.Total()) + additionalRequests
}

View File

@ -133,6 +133,10 @@ func (m *MockOutputWriter) Colorizer() aurora.Aurora {
return m.aurora return m.aurora
} }
func (m *MockOutputWriter) ResultCount() int {
return 0
}
// Write writes the event to file and/or screen. // Write writes the event to file and/or screen.
func (m *MockOutputWriter) Write(result *output.ResultEvent) error { func (m *MockOutputWriter) Write(result *output.ResultEvent) error {
if m.WriteCallback != nil { if m.WriteCallback != nil {

View File

@ -51,6 +51,8 @@ type FlowExecutor struct {
// these are keys whose values are meant to be flatten before executing // these are keys whose values are meant to be flatten before executing
// a request ex: if dynamic extractor returns ["value"] it will be converted to "value" // a request ex: if dynamic extractor returns ["value"] it will be converted to "value"
flattenKeys []string flattenKeys []string
executed *mapsutil.SyncLockMap[string, struct{}]
} }
// NewFlowExecutor creates a new flow executor from a list of requests // NewFlowExecutor creates a new flow executor from a list of requests
@ -98,6 +100,7 @@ func NewFlowExecutor(requests []protocols.Request, ctx *scan.ScanContext, option
results: results, results: results,
ctx: ctx, ctx: ctx,
program: program, program: program,
executed: mapsutil.NewSyncLockMap[string, struct{}](),
} }
return f, nil return f, nil
} }
@ -243,6 +246,7 @@ func (f *FlowExecutor) ExecuteWithResults(ctx *scan.ScanContext) error {
// pass flow and execute the js vm and handle errors // pass flow and execute the js vm and handle errors
_, err := runtime.RunProgram(f.program) _, err := runtime.RunProgram(f.program)
f.reconcileProgress()
if err != nil { if err != nil {
ctx.LogError(err) ctx.LogError(err)
return errorutil.NewWithErr(err).Msgf("failed to execute flow\n%v\n", f.options.Flow) return errorutil.NewWithErr(err).Msgf("failed to execute flow\n%v\n", f.options.Flow)
@ -256,6 +260,18 @@ func (f *FlowExecutor) ExecuteWithResults(ctx *scan.ScanContext) error {
return nil return nil
} }
func (f *FlowExecutor) reconcileProgress() {
for proto, list := range f.allProtocols {
for idx, req := range list {
key := requestKey(proto, req, strconv.Itoa(idx+1))
if _, seen := f.executed.Get(key); !seen {
// never executed → pretend it finished so that stats match
f.options.Progress.SetRequests(uint64(req.Requests()))
}
}
}
}
// GetRuntimeErrors returns all runtime errors (i.e errors from all protocol combined) // GetRuntimeErrors returns all runtime errors (i.e errors from all protocol combined)
func (f *FlowExecutor) GetRuntimeErrors() error { func (f *FlowExecutor) GetRuntimeErrors() error {
errs := []error{} errs := []error{}

View File

@ -75,6 +75,8 @@ func (f *FlowExecutor) requestExecutor(runtime *goja.Runtime, reqMap mapsutil.Ma
} }
} }
err := req.ExecuteWithResults(inputItem, output.InternalEvent(f.options.GetTemplateCtx(f.ctx.Input.MetaInput).GetAll()), output.InternalEvent{}, f.protocolResultCallback(req, matcherStatus, opts)) err := req.ExecuteWithResults(inputItem, output.InternalEvent(f.options.GetTemplateCtx(f.ctx.Input.MetaInput).GetAll()), output.InternalEvent{}, f.protocolResultCallback(req, matcherStatus, opts))
// Mark the request as seen
_ = f.executed.Set(requestKey(opts.protoName, req, id), struct{}{})
if err != nil { if err != nil {
index := id index := id
err = f.allErrs.Set(opts.protoName+":"+index, err) err = f.allErrs.Set(opts.protoName+":"+index, err)
@ -86,6 +88,13 @@ func (f *FlowExecutor) requestExecutor(runtime *goja.Runtime, reqMap mapsutil.Ma
return matcherStatus.Load() return matcherStatus.Load()
} }
func requestKey(proto string, req protocols.Request, id string) string {
if id == "" {
id = req.GetID()
}
return proto + ":" + id
}
// protocolResultCallback returns a callback that is executed // protocolResultCallback returns a callback that is executed
// after execution of each protocol request // after execution of each protocol request
func (f *FlowExecutor) protocolResultCallback(req protocols.Request, matcherStatus *atomic.Bool, _ *ProtoOptions) func(result *output.InternalWrappedEvent) { func (f *FlowExecutor) protocolResultCallback(req protocols.Request, matcherStatus *atomic.Bool, _ *ProtoOptions) func(result *output.InternalWrappedEvent) {