diff --git a/internal/runner/runner.go b/internal/runner/runner.go index 85fe0ea75..424d27116 100644 --- a/internal/runner/runner.go +++ b/internal/runner/runner.go @@ -702,6 +702,7 @@ func (r *Runner) RunEnumeration() error { }() } + now := time.Now() enumeration := false var results *atomic.Bool results, err = r.runStandardEnumeration(executorOpts, store, executorEngine) @@ -725,11 +726,17 @@ func (r *Runner) RunEnumeration() error { } r.fuzzFrequencyCache.Close() + r.progress.Stop() + timeTaken := time.Since(now) // todo: error propagation without canonical straight error check is required by cloud? // use safe dereferencing to avoid potential panics in case of previous unchecked errors if v := ptrutil.Safe(results); !v.Load() { - gologger.Info().Msgf("No results found. Better luck next time!") + gologger.Info().Msgf("Scan completed in %s. No results found.", shortDur(timeTaken)) + } else { + matchCount := r.output.ResultCount() + gologger.Info().Msgf("Scan completed in %s. %d matches found.", shortDur(timeTaken), matchCount) } + // check if a passive scan was requested but no target was provided if r.options.OfflineHTTP && len(r.options.Targets) == 0 && r.options.TargetsFilePath == "" { return errors.Wrap(err, "missing required input (http response) to run passive templates") @@ -738,6 +745,24 @@ func (r *Runner) RunEnumeration() error { return err } +func shortDur(d time.Duration) string { + if d < time.Minute { + return d.String() + } + + // Truncate to the nearest minute + d = d.Truncate(time.Minute) + s := d.String() + + if strings.HasSuffix(s, "m0s") { + s = s[:len(s)-2] + } + if strings.HasSuffix(s, "h0m") { + s = s[:len(s)-2] + } + return s +} + func (r *Runner) isInputNonHTTP() bool { var nonURLInput bool r.inputProvider.Iterate(func(value *contextargs.MetaInput) bool { diff --git a/pkg/core/execute_options.go b/pkg/core/execute_options.go index aa47bc44f..fae26b456 100644 --- a/pkg/core/execute_options.go +++ b/pkg/core/execute_options.go @@ -110,6 +110,8 @@ func (e *Engine) executeTemplateSpray(ctx context.Context, templatesList []*temp defer wp.Wait() for _, template := range templatesList { + template := template + select { case <-ctx.Done(): return results diff --git a/pkg/output/multi_writer.go b/pkg/output/multi_writer.go index 8ea729b4b..17b1c725a 100644 --- a/pkg/output/multi_writer.go +++ b/pkg/output/multi_writer.go @@ -65,3 +65,13 @@ func (mw *MultiWriter) RequestStatsLog(statusCode, response string) { writer.RequestStatsLog(statusCode, response) } } + +func (mw *MultiWriter) ResultCount() int { + count := 0 + for _, writer := range mw.writers { + if count := writer.ResultCount(); count > 0 { + return count + } + } + return count +} diff --git a/pkg/output/output.go b/pkg/output/output.go index 5c84bed30..e85774b83 100644 --- a/pkg/output/output.go +++ b/pkg/output/output.go @@ -54,6 +54,8 @@ type Writer interface { RequestStatsLog(statusCode, response string) // WriteStoreDebugData writes the request/response debug data to file WriteStoreDebugData(host, templateID, eventType string, data string) + // ResultCount returns the total number of results written + ResultCount() int } // StandardWriter is a writer writing output to file and screen for results. @@ -79,6 +81,8 @@ type StandardWriter struct { // JSONLogRequestHook is a hook that can be used to log request/response // when using custom server code with output JSONLogRequestHook func(*JSONLogRequest) + + resultCount atomic.Int32 } var _ Writer = &StandardWriter{} @@ -287,6 +291,10 @@ func NewStandardWriter(options *types.Options) (*StandardWriter, error) { return writer, nil } +func (w *StandardWriter) ResultCount() int { + return int(w.resultCount.Load()) +} + // Write writes the event to file and/or screen. func (w *StandardWriter) Write(event *ResultEvent) error { // Enrich the result event with extra metadata on the template-path and url. @@ -336,6 +344,7 @@ func (w *StandardWriter) Write(event *ResultEvent) error { _, _ = w.outputFile.Write([]byte("\n")) } } + w.resultCount.Add(1) return nil } diff --git a/pkg/output/output_stats.go b/pkg/output/output_stats.go index 7b0d509cd..68a234d85 100644 --- a/pkg/output/output_stats.go +++ b/pkg/output/output_stats.go @@ -49,3 +49,6 @@ func (tw *StatsOutputWriter) RequestStatsLog(statusCode, response string) { tw.Tracker.TrackStatusCode(statusCode) tw.Tracker.TrackWAFDetected(response) } +func (tw *StatsOutputWriter) ResultCount() int { + return 0 +} diff --git a/pkg/progress/progress.go b/pkg/progress/progress.go index 1ffb22cee..853fb103d 100644 --- a/pkg/progress/progress.go +++ b/pkg/progress/progress.go @@ -120,9 +120,7 @@ func (p *StatsTicker) IncrementRequests() { // SetRequests sets the counter by incrementing it with a delta func (p *StatsTicker) SetRequests(count uint64) { - value, _ := p.stats.GetCounter("requests") - delta := count - value - p.stats.IncrementCounter("requests", int(delta)) + p.stats.IncrementCounter("requests", int(count)) } // IncrementMatched increments the matched counter by 1. diff --git a/pkg/protocols/common/hosterrorscache/hosterrorscache.go b/pkg/protocols/common/hosterrorscache/hosterrorscache.go index 3943eef7e..3039dbdf0 100644 --- a/pkg/protocols/common/hosterrorscache/hosterrorscache.go +++ b/pkg/protocols/common/hosterrorscache/hosterrorscache.go @@ -89,6 +89,9 @@ func (c *Cache) NormalizeCacheValue(value string) string { u, err := url.ParseRequestURI(value) if err != nil || u.Host == "" { + if strings.Contains(value, ":") { + return normalizedValue + } u, err2 := url.ParseRequestURI("https://" + value) if err2 != nil { return normalizedValue @@ -236,14 +239,19 @@ func (c *Cache) GetKeyFromContext(ctx *contextargs.Context, err error) string { // should be reflected in contextargs but it is not yet reflected in some cases // and needs refactor of ScanContext + ContextArgs to achieve that // i.e why we use real address from error if present - address := ctx.MetaInput.Address() - // get address override from error + var address string + + // 1. the address carried inside the error (if the transport sets it) if err != nil { - tmp := errkit.GetAttrValue(err, "address") - if tmp.Any() != nil { - address = tmp.String() + if v := errkit.GetAttrValue(err, "address"); v.Any() != nil { + address = v.String() } } + + if address == "" { + address = ctx.MetaInput.Address() + } + finalValue := c.NormalizeCacheValue(address) return finalValue } diff --git a/pkg/protocols/http/build_request.go b/pkg/protocols/http/build_request.go index 3cde12d88..1cb9553c9 100644 --- a/pkg/protocols/http/build_request.go +++ b/pkg/protocols/http/build_request.go @@ -123,14 +123,6 @@ func (g *generatedRequest) URL() string { return "" } -// Total returns the total number of requests for the generator -func (r *requestGenerator) Total() int { - if r.payloadIterator != nil { - return len(r.request.Raw) * r.payloadIterator.Remaining() - } - return len(r.request.Path) -} - // Make creates a http request for the provided input. // It returns ErrNoMoreRequests as error when all the requests have been exhausted. func (r *requestGenerator) Make(ctx context.Context, input *contextargs.Context, reqData string, payloads, dynamicValues map[string]interface{}) (gr *generatedRequest, err error) { diff --git a/pkg/protocols/http/http.go b/pkg/protocols/http/http.go index 78710f79c..0b30a7408 100644 --- a/pkg/protocols/http/http.go +++ b/pkg/protocols/http/http.go @@ -501,7 +501,6 @@ func (request *Request) Compile(options *protocols.ExecutorOptions) error { request.Threads = options.GetThreadsForNPayloadRequests(request.Requests(), request.Threads) } } - return nil } @@ -517,24 +516,8 @@ func (request *Request) RebuildGenerator() error { // Requests returns the total number of requests the YAML rule will perform func (request *Request) Requests() int { - if request.generator != nil { - payloadRequests := request.generator.NewIterator().Total() - if len(request.Raw) > 0 { - payloadRequests = payloadRequests * len(request.Raw) - } - if len(request.Path) > 0 { - payloadRequests = payloadRequests * len(request.Path) - } - return payloadRequests - } - if len(request.Raw) > 0 { - requests := len(request.Raw) - if requests == 1 && request.RaceNumberRequests != 0 { - requests *= request.RaceNumberRequests - } - return requests - } - return len(request.Path) + generator := request.newGenerator(false) + return generator.Total() } const ( diff --git a/pkg/protocols/http/request.go b/pkg/protocols/http/request.go index 2cc32f5bf..6d8ad3e1d 100644 --- a/pkg/protocols/http/request.go +++ b/pkg/protocols/http/request.go @@ -41,7 +41,6 @@ import ( "github.com/projectdiscovery/rawhttp" convUtil "github.com/projectdiscovery/utils/conversion" "github.com/projectdiscovery/utils/errkit" - errorutil "github.com/projectdiscovery/utils/errors" httpUtils "github.com/projectdiscovery/utils/http" "github.com/projectdiscovery/utils/reader" sliceutil "github.com/projectdiscovery/utils/slice" @@ -484,7 +483,6 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, dynamicVa if err == types.ErrNoMoreRequests { return true, nil } - request.options.Progress.IncrementFailedRequestsBy(int64(generator.Total())) return true, err } // ideally if http template used a custom port or hostname @@ -541,14 +539,19 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, dynamicVa if errors.Is(execReqErr, ErrMissingVars) { return true, nil } + if execReqErr != nil { + request.markHostError(updatedInput, execReqErr) + // if applicable mark the host as unresponsive - requestErr = errorutil.NewWithErr(execReqErr).Msgf("got err while executing %v", generatedHttpRequest.URL()) + reqKitErr := errkit.FromError(execReqErr) + reqKitErr.Msgf("got err while executing %v", generatedHttpRequest.URL()) + + requestErr = reqKitErr request.options.Progress.IncrementFailedRequestsBy(1) } else { request.options.Progress.IncrementRequests() } - request.markHostError(updatedInput, execReqErr) // If this was a match, and we want to stop at first match, skip all further requests. shouldStopAtFirstMatch := generatedHttpRequest.original.options.Options.StopAtFirstMatch || generatedHttpRequest.original.options.StopAtFirstMatch || request.StopAtFirstMatch @@ -585,6 +588,7 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, dynamicVa requestErr = gotErr } if skip || gotErr != nil { + request.options.Progress.SetRequests(uint64(generator.Remaining() + 1)) break } } @@ -1212,7 +1216,7 @@ func (request *Request) newContext(input *contextargs.Context) context.Context { // markHostError checks if the error is a unreponsive host error and marks it func (request *Request) markHostError(input *contextargs.Context, err error) { - if request.options.HostErrorsCache != nil { + if request.options.HostErrorsCache != nil && err != nil { request.options.HostErrorsCache.MarkFailedOrRemove(request.options.ProtocolType.String(), input, err) } } diff --git a/pkg/protocols/http/request_generator.go b/pkg/protocols/http/request_generator.go index b15df1be9..4c4c701a8 100644 --- a/pkg/protocols/http/request_generator.go +++ b/pkg/protocols/http/request_generator.go @@ -135,3 +135,67 @@ func (r *requestGenerator) hasMarker(request string, mark flowMark) bool { fo, hasOverrides := parseFlowAnnotations(request) return hasOverrides && fo == mark } + +// Remaining returns the number of requests that are still left to be +// generated (and therefore to be sent) by this generator. +func (r *requestGenerator) Remaining() int { + var sequence []string + switch { + case len(r.request.Path) > 0: + sequence = r.request.Path + case len(r.request.Raw) > 0: + sequence = r.request.Raw + default: + return 0 + } + + remainingInCurrentPass := 0 + for i := r.currentIndex; i < len(sequence); i++ { + if !r.hasMarker(sequence[i], Once) { + remainingInCurrentPass++ + } + } + + if r.payloadIterator == nil { + return remainingInCurrentPass + } + + numRemainingPayloadSets := r.payloadIterator.Remaining() + totalValidInSequence := 0 + for _, req := range sequence { + if !r.hasMarker(req, Once) { + totalValidInSequence++ + } + } + + // Total remaining = remaining in current pass + (remaining payload sets * requests per full pass) + return remainingInCurrentPass + numRemainingPayloadSets*totalValidInSequence +} + +func (r *requestGenerator) Total() int { + var sequence []string + switch { + case len(r.request.Path) > 0: + sequence = r.request.Path + case len(r.request.Raw) > 0: + sequence = r.request.Raw + default: + return 0 + } + + applicableRequests := 0 + additionalRequests := 0 + for _, request := range sequence { + if !r.hasMarker(request, Once) { + applicableRequests++ + } else { + additionalRequests++ + } + } + + if r.payloadIterator == nil { + return applicableRequests + additionalRequests + } + + return (applicableRequests * r.payloadIterator.Total()) + additionalRequests +} diff --git a/pkg/testutils/testutils.go b/pkg/testutils/testutils.go index d59af2f7b..5f791c2c1 100644 --- a/pkg/testutils/testutils.go +++ b/pkg/testutils/testutils.go @@ -133,6 +133,10 @@ func (m *MockOutputWriter) Colorizer() aurora.Aurora { return m.aurora } +func (m *MockOutputWriter) ResultCount() int { + return 0 +} + // Write writes the event to file and/or screen. func (m *MockOutputWriter) Write(result *output.ResultEvent) error { if m.WriteCallback != nil { diff --git a/pkg/tmplexec/flow/flow_executor.go b/pkg/tmplexec/flow/flow_executor.go index 6e71cf840..62112e0f8 100644 --- a/pkg/tmplexec/flow/flow_executor.go +++ b/pkg/tmplexec/flow/flow_executor.go @@ -51,6 +51,8 @@ type FlowExecutor struct { // these are keys whose values are meant to be flatten before executing // a request ex: if dynamic extractor returns ["value"] it will be converted to "value" flattenKeys []string + + executed *mapsutil.SyncLockMap[string, struct{}] } // NewFlowExecutor creates a new flow executor from a list of requests @@ -98,6 +100,7 @@ func NewFlowExecutor(requests []protocols.Request, ctx *scan.ScanContext, option results: results, ctx: ctx, program: program, + executed: mapsutil.NewSyncLockMap[string, struct{}](), } return f, nil } @@ -243,6 +246,7 @@ func (f *FlowExecutor) ExecuteWithResults(ctx *scan.ScanContext) error { // pass flow and execute the js vm and handle errors _, err := runtime.RunProgram(f.program) + f.reconcileProgress() if err != nil { ctx.LogError(err) return errorutil.NewWithErr(err).Msgf("failed to execute flow\n%v\n", f.options.Flow) @@ -256,6 +260,18 @@ func (f *FlowExecutor) ExecuteWithResults(ctx *scan.ScanContext) error { return nil } +func (f *FlowExecutor) reconcileProgress() { + for proto, list := range f.allProtocols { + for idx, req := range list { + key := requestKey(proto, req, strconv.Itoa(idx+1)) + if _, seen := f.executed.Get(key); !seen { + // never executed → pretend it finished so that stats match + f.options.Progress.SetRequests(uint64(req.Requests())) + } + } + } +} + // GetRuntimeErrors returns all runtime errors (i.e errors from all protocol combined) func (f *FlowExecutor) GetRuntimeErrors() error { errs := []error{} diff --git a/pkg/tmplexec/flow/flow_internal.go b/pkg/tmplexec/flow/flow_internal.go index 92a852f9d..03cc29596 100644 --- a/pkg/tmplexec/flow/flow_internal.go +++ b/pkg/tmplexec/flow/flow_internal.go @@ -75,6 +75,8 @@ func (f *FlowExecutor) requestExecutor(runtime *goja.Runtime, reqMap mapsutil.Ma } } err := req.ExecuteWithResults(inputItem, output.InternalEvent(f.options.GetTemplateCtx(f.ctx.Input.MetaInput).GetAll()), output.InternalEvent{}, f.protocolResultCallback(req, matcherStatus, opts)) + // Mark the request as seen + _ = f.executed.Set(requestKey(opts.protoName, req, id), struct{}{}) if err != nil { index := id err = f.allErrs.Set(opts.protoName+":"+index, err) @@ -86,6 +88,13 @@ func (f *FlowExecutor) requestExecutor(runtime *goja.Runtime, reqMap mapsutil.Ma return matcherStatus.Load() } +func requestKey(proto string, req protocols.Request, id string) string { + if id == "" { + id = req.GetID() + } + return proto + ":" + id +} + // protocolResultCallback returns a callback that is executed // after execution of each protocol request func (f *FlowExecutor) protocolResultCallback(req protocols.Request, matcherStatus *atomic.Bool, _ *ProtoOptions) func(result *output.InternalWrappedEvent) {