feat: fixed max-host-error blocking + progress mismatch + misc (#6193)

* feat: fixed max-host-error blocking wrong port for template with error

* feat: log total results with time taken at end of execution

* bugfix: skip non-executed requests with progress in flow protocol

* feat: fixed request calculation in http protocol for progress

* misc adjustments

---------

Co-authored-by: Ice3man <nizamulrana@gmail.com>
This commit is contained in:
Sandeep Singh 2025-05-07 17:22:15 +05:30 committed by GitHub
parent b9d0f2585f
commit 4801cc65ef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 168 additions and 41 deletions

View File

@ -702,6 +702,7 @@ func (r *Runner) RunEnumeration() error {
}()
}
now := time.Now()
enumeration := false
var results *atomic.Bool
results, err = r.runStandardEnumeration(executorOpts, store, executorEngine)
@ -725,11 +726,17 @@ func (r *Runner) RunEnumeration() error {
}
r.fuzzFrequencyCache.Close()
r.progress.Stop()
timeTaken := time.Since(now)
// todo: error propagation without canonical straight error check is required by cloud?
// use safe dereferencing to avoid potential panics in case of previous unchecked errors
if v := ptrutil.Safe(results); !v.Load() {
gologger.Info().Msgf("No results found. Better luck next time!")
gologger.Info().Msgf("Scan completed in %s. No results found.", shortDur(timeTaken))
} else {
matchCount := r.output.ResultCount()
gologger.Info().Msgf("Scan completed in %s. %d matches found.", shortDur(timeTaken), matchCount)
}
// check if a passive scan was requested but no target was provided
if r.options.OfflineHTTP && len(r.options.Targets) == 0 && r.options.TargetsFilePath == "" {
return errors.Wrap(err, "missing required input (http response) to run passive templates")
@ -738,6 +745,24 @@ func (r *Runner) RunEnumeration() error {
return err
}
func shortDur(d time.Duration) string {
if d < time.Minute {
return d.String()
}
// Truncate to the nearest minute
d = d.Truncate(time.Minute)
s := d.String()
if strings.HasSuffix(s, "m0s") {
s = s[:len(s)-2]
}
if strings.HasSuffix(s, "h0m") {
s = s[:len(s)-2]
}
return s
}
func (r *Runner) isInputNonHTTP() bool {
var nonURLInput bool
r.inputProvider.Iterate(func(value *contextargs.MetaInput) bool {

View File

@ -110,6 +110,8 @@ func (e *Engine) executeTemplateSpray(ctx context.Context, templatesList []*temp
defer wp.Wait()
for _, template := range templatesList {
template := template
select {
case <-ctx.Done():
return results

View File

@ -65,3 +65,13 @@ func (mw *MultiWriter) RequestStatsLog(statusCode, response string) {
writer.RequestStatsLog(statusCode, response)
}
}
func (mw *MultiWriter) ResultCount() int {
count := 0
for _, writer := range mw.writers {
if count := writer.ResultCount(); count > 0 {
return count
}
}
return count
}

View File

@ -54,6 +54,8 @@ type Writer interface {
RequestStatsLog(statusCode, response string)
// WriteStoreDebugData writes the request/response debug data to file
WriteStoreDebugData(host, templateID, eventType string, data string)
// ResultCount returns the total number of results written
ResultCount() int
}
// StandardWriter is a writer writing output to file and screen for results.
@ -79,6 +81,8 @@ type StandardWriter struct {
// JSONLogRequestHook is a hook that can be used to log request/response
// when using custom server code with output
JSONLogRequestHook func(*JSONLogRequest)
resultCount atomic.Int32
}
var _ Writer = &StandardWriter{}
@ -287,6 +291,10 @@ func NewStandardWriter(options *types.Options) (*StandardWriter, error) {
return writer, nil
}
func (w *StandardWriter) ResultCount() int {
return int(w.resultCount.Load())
}
// Write writes the event to file and/or screen.
func (w *StandardWriter) Write(event *ResultEvent) error {
// Enrich the result event with extra metadata on the template-path and url.
@ -336,6 +344,7 @@ func (w *StandardWriter) Write(event *ResultEvent) error {
_, _ = w.outputFile.Write([]byte("\n"))
}
}
w.resultCount.Add(1)
return nil
}

View File

@ -49,3 +49,6 @@ func (tw *StatsOutputWriter) RequestStatsLog(statusCode, response string) {
tw.Tracker.TrackStatusCode(statusCode)
tw.Tracker.TrackWAFDetected(response)
}
func (tw *StatsOutputWriter) ResultCount() int {
return 0
}

View File

@ -120,9 +120,7 @@ func (p *StatsTicker) IncrementRequests() {
// SetRequests sets the counter by incrementing it with a delta
func (p *StatsTicker) SetRequests(count uint64) {
value, _ := p.stats.GetCounter("requests")
delta := count - value
p.stats.IncrementCounter("requests", int(delta))
p.stats.IncrementCounter("requests", int(count))
}
// IncrementMatched increments the matched counter by 1.

View File

@ -89,6 +89,9 @@ func (c *Cache) NormalizeCacheValue(value string) string {
u, err := url.ParseRequestURI(value)
if err != nil || u.Host == "" {
if strings.Contains(value, ":") {
return normalizedValue
}
u, err2 := url.ParseRequestURI("https://" + value)
if err2 != nil {
return normalizedValue
@ -236,14 +239,19 @@ func (c *Cache) GetKeyFromContext(ctx *contextargs.Context, err error) string {
// should be reflected in contextargs but it is not yet reflected in some cases
// and needs refactor of ScanContext + ContextArgs to achieve that
// i.e why we use real address from error if present
address := ctx.MetaInput.Address()
// get address override from error
var address string
// 1. the address carried inside the error (if the transport sets it)
if err != nil {
tmp := errkit.GetAttrValue(err, "address")
if tmp.Any() != nil {
address = tmp.String()
if v := errkit.GetAttrValue(err, "address"); v.Any() != nil {
address = v.String()
}
}
if address == "" {
address = ctx.MetaInput.Address()
}
finalValue := c.NormalizeCacheValue(address)
return finalValue
}

View File

@ -123,14 +123,6 @@ func (g *generatedRequest) URL() string {
return ""
}
// Total returns the total number of requests for the generator
func (r *requestGenerator) Total() int {
if r.payloadIterator != nil {
return len(r.request.Raw) * r.payloadIterator.Remaining()
}
return len(r.request.Path)
}
// Make creates a http request for the provided input.
// It returns ErrNoMoreRequests as error when all the requests have been exhausted.
func (r *requestGenerator) Make(ctx context.Context, input *contextargs.Context, reqData string, payloads, dynamicValues map[string]interface{}) (gr *generatedRequest, err error) {

View File

@ -501,7 +501,6 @@ func (request *Request) Compile(options *protocols.ExecutorOptions) error {
request.Threads = options.GetThreadsForNPayloadRequests(request.Requests(), request.Threads)
}
}
return nil
}
@ -517,24 +516,8 @@ func (request *Request) RebuildGenerator() error {
// Requests returns the total number of requests the YAML rule will perform
func (request *Request) Requests() int {
if request.generator != nil {
payloadRequests := request.generator.NewIterator().Total()
if len(request.Raw) > 0 {
payloadRequests = payloadRequests * len(request.Raw)
}
if len(request.Path) > 0 {
payloadRequests = payloadRequests * len(request.Path)
}
return payloadRequests
}
if len(request.Raw) > 0 {
requests := len(request.Raw)
if requests == 1 && request.RaceNumberRequests != 0 {
requests *= request.RaceNumberRequests
}
return requests
}
return len(request.Path)
generator := request.newGenerator(false)
return generator.Total()
}
const (

View File

@ -41,7 +41,6 @@ import (
"github.com/projectdiscovery/rawhttp"
convUtil "github.com/projectdiscovery/utils/conversion"
"github.com/projectdiscovery/utils/errkit"
errorutil "github.com/projectdiscovery/utils/errors"
httpUtils "github.com/projectdiscovery/utils/http"
"github.com/projectdiscovery/utils/reader"
sliceutil "github.com/projectdiscovery/utils/slice"
@ -484,7 +483,6 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, dynamicVa
if err == types.ErrNoMoreRequests {
return true, nil
}
request.options.Progress.IncrementFailedRequestsBy(int64(generator.Total()))
return true, err
}
// ideally if http template used a custom port or hostname
@ -541,14 +539,19 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, dynamicVa
if errors.Is(execReqErr, ErrMissingVars) {
return true, nil
}
if execReqErr != nil {
request.markHostError(updatedInput, execReqErr)
// if applicable mark the host as unresponsive
requestErr = errorutil.NewWithErr(execReqErr).Msgf("got err while executing %v", generatedHttpRequest.URL())
reqKitErr := errkit.FromError(execReqErr)
reqKitErr.Msgf("got err while executing %v", generatedHttpRequest.URL())
requestErr = reqKitErr
request.options.Progress.IncrementFailedRequestsBy(1)
} else {
request.options.Progress.IncrementRequests()
}
request.markHostError(updatedInput, execReqErr)
// If this was a match, and we want to stop at first match, skip all further requests.
shouldStopAtFirstMatch := generatedHttpRequest.original.options.Options.StopAtFirstMatch || generatedHttpRequest.original.options.StopAtFirstMatch || request.StopAtFirstMatch
@ -585,6 +588,7 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, dynamicVa
requestErr = gotErr
}
if skip || gotErr != nil {
request.options.Progress.SetRequests(uint64(generator.Remaining() + 1))
break
}
}
@ -1212,7 +1216,7 @@ func (request *Request) newContext(input *contextargs.Context) context.Context {
// markHostError checks if the error is a unreponsive host error and marks it
func (request *Request) markHostError(input *contextargs.Context, err error) {
if request.options.HostErrorsCache != nil {
if request.options.HostErrorsCache != nil && err != nil {
request.options.HostErrorsCache.MarkFailedOrRemove(request.options.ProtocolType.String(), input, err)
}
}

View File

@ -135,3 +135,67 @@ func (r *requestGenerator) hasMarker(request string, mark flowMark) bool {
fo, hasOverrides := parseFlowAnnotations(request)
return hasOverrides && fo == mark
}
// Remaining returns the number of requests that are still left to be
// generated (and therefore to be sent) by this generator.
func (r *requestGenerator) Remaining() int {
var sequence []string
switch {
case len(r.request.Path) > 0:
sequence = r.request.Path
case len(r.request.Raw) > 0:
sequence = r.request.Raw
default:
return 0
}
remainingInCurrentPass := 0
for i := r.currentIndex; i < len(sequence); i++ {
if !r.hasMarker(sequence[i], Once) {
remainingInCurrentPass++
}
}
if r.payloadIterator == nil {
return remainingInCurrentPass
}
numRemainingPayloadSets := r.payloadIterator.Remaining()
totalValidInSequence := 0
for _, req := range sequence {
if !r.hasMarker(req, Once) {
totalValidInSequence++
}
}
// Total remaining = remaining in current pass + (remaining payload sets * requests per full pass)
return remainingInCurrentPass + numRemainingPayloadSets*totalValidInSequence
}
func (r *requestGenerator) Total() int {
var sequence []string
switch {
case len(r.request.Path) > 0:
sequence = r.request.Path
case len(r.request.Raw) > 0:
sequence = r.request.Raw
default:
return 0
}
applicableRequests := 0
additionalRequests := 0
for _, request := range sequence {
if !r.hasMarker(request, Once) {
applicableRequests++
} else {
additionalRequests++
}
}
if r.payloadIterator == nil {
return applicableRequests + additionalRequests
}
return (applicableRequests * r.payloadIterator.Total()) + additionalRequests
}

View File

@ -133,6 +133,10 @@ func (m *MockOutputWriter) Colorizer() aurora.Aurora {
return m.aurora
}
func (m *MockOutputWriter) ResultCount() int {
return 0
}
// Write writes the event to file and/or screen.
func (m *MockOutputWriter) Write(result *output.ResultEvent) error {
if m.WriteCallback != nil {

View File

@ -51,6 +51,8 @@ type FlowExecutor struct {
// these are keys whose values are meant to be flatten before executing
// a request ex: if dynamic extractor returns ["value"] it will be converted to "value"
flattenKeys []string
executed *mapsutil.SyncLockMap[string, struct{}]
}
// NewFlowExecutor creates a new flow executor from a list of requests
@ -98,6 +100,7 @@ func NewFlowExecutor(requests []protocols.Request, ctx *scan.ScanContext, option
results: results,
ctx: ctx,
program: program,
executed: mapsutil.NewSyncLockMap[string, struct{}](),
}
return f, nil
}
@ -243,6 +246,7 @@ func (f *FlowExecutor) ExecuteWithResults(ctx *scan.ScanContext) error {
// pass flow and execute the js vm and handle errors
_, err := runtime.RunProgram(f.program)
f.reconcileProgress()
if err != nil {
ctx.LogError(err)
return errorutil.NewWithErr(err).Msgf("failed to execute flow\n%v\n", f.options.Flow)
@ -256,6 +260,18 @@ func (f *FlowExecutor) ExecuteWithResults(ctx *scan.ScanContext) error {
return nil
}
func (f *FlowExecutor) reconcileProgress() {
for proto, list := range f.allProtocols {
for idx, req := range list {
key := requestKey(proto, req, strconv.Itoa(idx+1))
if _, seen := f.executed.Get(key); !seen {
// never executed → pretend it finished so that stats match
f.options.Progress.SetRequests(uint64(req.Requests()))
}
}
}
}
// GetRuntimeErrors returns all runtime errors (i.e errors from all protocol combined)
func (f *FlowExecutor) GetRuntimeErrors() error {
errs := []error{}

View File

@ -75,6 +75,8 @@ func (f *FlowExecutor) requestExecutor(runtime *goja.Runtime, reqMap mapsutil.Ma
}
}
err := req.ExecuteWithResults(inputItem, output.InternalEvent(f.options.GetTemplateCtx(f.ctx.Input.MetaInput).GetAll()), output.InternalEvent{}, f.protocolResultCallback(req, matcherStatus, opts))
// Mark the request as seen
_ = f.executed.Set(requestKey(opts.protoName, req, id), struct{}{})
if err != nil {
index := id
err = f.allErrs.Set(opts.protoName+":"+index, err)
@ -86,6 +88,13 @@ func (f *FlowExecutor) requestExecutor(runtime *goja.Runtime, reqMap mapsutil.Ma
return matcherStatus.Load()
}
func requestKey(proto string, req protocols.Request, id string) string {
if id == "" {
id = req.GetID()
}
return proto + ":" + id
}
// protocolResultCallback returns a callback that is executed
// after execution of each protocol request
func (f *FlowExecutor) protocolResultCallback(req protocols.Request, matcherStatus *atomic.Bool, _ *ProtoOptions) func(result *output.InternalWrappedEvent) {