Ice3man 4170e1cbb8
more goroutine leak fixes to nuclei (#5188)
* more goroutine leak fixes to nuclei

* run only dns templates for test

* updated httpx to dev

* dep update

---------

Co-authored-by: sandeep <8293321+ehsandeep@users.noreply.github.com>
2024-05-20 00:48:40 +05:30

214 lines
4.8 KiB
Go

package httputils
import (
"context"
"sync"
syncutil "github.com/projectdiscovery/utils/sync"
"golang.org/x/exp/maps"
)
// WorkPoolType identifies which kind of work pool a handler uses to track
// in-flight work.
type WorkPoolType uint

// Supported work pool types.
const (
	// Blocking makes the addition of new work block while the pool is full.
	Blocking WorkPoolType = iota
	// NonBlocking never blocks the addition of new work, even when the pool
	// is full.
	NonBlocking
)
// StopAtFirstMatchHandler is a handler that executes
// request and stops on first match.
//
// It bundles a work pool (used through Acquire/Release/Wait), a cancellable
// context that signals "stop" to in-flight work, and a collector goroutine
// that drains ResultChan into a de-duplicated result set.
type StopAtFirstMatchHandler[T comparable] struct {
// once ensures the callback given to MatchCallback runs at most once when
// stop at first match is enabled.
once sync.Once
// Result Channel
// Values sent here are consumed by the internal run goroutine; the channel
// is closed by Wait.
ResultChan chan T
// work pool and its type
// poolType selects which of sgPool / wgPool is used by Acquire, Release
// and Wait.
poolType WorkPoolType
// sgPool is only set by NewBlockingSPMHandler.
sgPool *syncutil.AdaptiveWaitGroup
// wgPool is only set by NewNonBlockingSPMHandler.
wgPool *sync.WaitGroup
// internal / unexported
// ctx is derived from the context passed to the constructor; cancel
// cancels it (see Trigger and Cancel).
ctx context.Context
cancel context.CancelFunc
// internalWg tracks the run goroutine so Wait can block until it exits.
internalWg *sync.WaitGroup
// results is the set of distinct values received on ResultChan, capped at
// maxResults entries.
results map[T]struct{}
// onResult, when non-nil, is invoked by run for every received value.
onResult func(T)
// stopEnabled reports whether stop at first match behaviour is active.
stopEnabled bool
// maxResults is the maximum number of entries kept in results.
maxResults int
}
// NewBlockingSPMHandler creates a stop at first match handler backed by a
// blocking, size-limited work pool and starts its result collector goroutine.
//
// ctx is the parent context from which the handler derives its own
// cancellable context. size is passed to the adaptive wait group as the pool
// size, maxResults caps how many distinct results are stored, and spm enables
// stop at first match behaviour.
func NewBlockingSPMHandler[T comparable](ctx context.Context, size int, maxResults int, spm bool) *StopAtFirstMatchHandler[T] {
	spmCtx, cancelFn := context.WithCancel(ctx)

	// NOTE(review): the error returned by syncutil.New is discarded; if
	// construction can fail (e.g. for an invalid size) pool may be nil and
	// later pool operations would panic — confirm against the library.
	pool, _ := syncutil.New(syncutil.WithSize(size))

	collectorWg := &sync.WaitGroup{}
	handler := &StopAtFirstMatchHandler[T]{
		poolType:    Blocking,
		sgPool:      pool,
		ResultChan:  make(chan T, 1),
		results:     make(map[T]struct{}),
		maxResults:  maxResults,
		stopEnabled: spm,
		ctx:         spmCtx,
		cancel:      cancelFn,
		internalWg:  collectorWg,
	}

	handler.internalWg.Add(1)
	// The collector watches the parent ctx, not the derived one, so Trigger
	// and Cancel do not by themselves stop result collection.
	go handler.run(ctx)
	return handler
}
// NewNonBlockingSPMHandler creates a stop at first match handler backed by a
// plain sync.WaitGroup, so Acquire never blocks, and starts its result
// collector goroutine.
//
// ctx is the parent context from which the handler derives its own
// cancellable context. maxResults caps how many distinct results are stored
// and spm enables stop at first match behaviour.
func NewNonBlockingSPMHandler[T comparable](ctx context.Context, maxResults int, spm bool) *StopAtFirstMatchHandler[T] {
	spmCtx, cancelFn := context.WithCancel(ctx)

	collectorWg := &sync.WaitGroup{}
	handler := &StopAtFirstMatchHandler[T]{
		poolType:    NonBlocking,
		wgPool:      &sync.WaitGroup{},
		ResultChan:  make(chan T, 1),
		results:     make(map[T]struct{}),
		maxResults:  maxResults,
		stopEnabled: spm,
		ctx:         spmCtx,
		cancel:      cancelFn,
		internalWg:  collectorWg,
	}

	handler.internalWg.Add(1)
	// The collector watches the parent ctx, not the derived one, so Trigger
	// and Cancel do not by themselves stop result collection.
	go handler.run(ctx)
	return handler
}
// Trigger signals that a match was found. When stop at first match is enabled
// it cancels the handler's context so that work observing Done can stop;
// otherwise it is a no-op.
func (h *StopAtFirstMatchHandler[T]) Trigger() {
	if !h.stopEnabled {
		return
	}
	h.cancel()
}
// Cancel unconditionally cancels the handler's (spm) context, regardless of
// whether stop at first match is enabled.
func (h *StopAtFirstMatchHandler[T]) Cancel() {
h.cancel()
}
// SetOnResultCallback registers fn to be invoked for every value received on
// ResultChan. Callbacks accumulate: if one is already registered, the new one
// is chained after it, so callbacks run in registration order.
//
// A nil fn is ignored; chaining it would otherwise produce a callback that
// panics when the collector goroutine invokes it.
//
// This is not thread safe: the callback field is read by the collector
// goroutine without synchronization.
func (h *StopAtFirstMatchHandler[T]) SetOnResultCallback(fn func(T)) {
	if fn == nil {
		return
	}
	prev := h.onResult
	if prev == nil {
		h.onResult = fn
		return
	}
	h.onResult = func(t T) {
		prev(t)
		fn(t)
	}
}
// MatchCallback runs fn in response to a match.
//
// With stop at first match enabled, fn is executed at most once over the
// lifetime of the handler (subsequent calls are dropped). With it disabled,
// fn is executed on every call, so using this function in that case is
// discouraged.
func (h *StopAtFirstMatchHandler[T]) MatchCallback(fn func()) {
	if h.stopEnabled {
		h.once.Do(fn)
		return
	}
	fn()
}
// run is the collector loop executed in its own goroutine. It drains
// ResultChan until either ctx is cancelled or the channel is closed, invoking
// the registered result callback for every value and storing values in the
// result set while it holds fewer than maxResults entries.
func (h *StopAtFirstMatchHandler[T]) run(ctx context.Context) {
	defer h.internalWg.Done()

	for {
		select {
		case <-ctx.Done():
			return
		case result, open := <-h.ResultChan:
			if !open {
				return
			}
			if h.onResult != nil {
				h.onResult(result)
			}
			// Only store while below the cap; the callback above still
			// fires for results that are not stored.
			if len(h.results) < h.maxResults {
				h.results[result] = struct{}{}
			}
		}
	}
}
// Done returns the Done channel of the handler's context. It is closed once
// that context is cancelled: by Trigger when stop at first match is enabled,
// by Cancel, or when the parent context given to the constructor is cancelled.
func (h *StopAtFirstMatchHandler[T]) Done() <-chan struct{} {
return h.ctx.Done()
}
// Cancelled returns true if the handler's context is cancelled, whatever the
// cause (Trigger, Cancel, or cancellation of the parent context).
func (h *StopAtFirstMatchHandler[T]) Cancelled() bool {
return h.ctx.Err() != nil
}
// FoundFirstMatch reports whether the handler's context has been cancelled
// while stop at first match mode is enabled. It always returns false when the
// mode is disabled.
func (h *StopAtFirstMatchHandler[T]) FoundFirstMatch() bool {
	return h.ctx.Err() != nil && h.stopEnabled
}
// Acquire registers one unit of work with the work pool matching the
// handler's pool type. Every call must be paired with a call to Release.
func (h *StopAtFirstMatchHandler[T]) Acquire() {
	if h.poolType == Blocking {
		h.sgPool.Add()
		return
	}
	if h.poolType == NonBlocking {
		h.wgPool.Add(1)
	}
}
// Release marks one unit of work, previously registered with Acquire, as
// finished in the work pool matching the handler's pool type.
func (h *StopAtFirstMatchHandler[T]) Release() {
	if h.poolType == Blocking {
		h.sgPool.Done()
		return
	}
	if h.poolType == NonBlocking {
		h.wgPool.Done()
	}
}
// Resize changes the size of the blocking work pool to size, doing nothing
// when the pool already has that size. It returns the error reported by the
// pool's Resize, if any.
//
// Handlers without a sized pool (those created by NewNonBlockingSPMHandler)
// have nothing to resize, so the call is a no-op returning nil instead of
// dereferencing a nil pool.
func (h *StopAtFirstMatchHandler[T]) Resize(ctx context.Context, size int) error {
	if h.sgPool == nil {
		return nil
	}
	if h.sgPool.Size == size {
		return nil
	}
	return h.sgPool.Resize(ctx, size)
}
// Size returns the current size of the blocking work pool. It returns 0 for
// handlers without a sized pool (those created by NewNonBlockingSPMHandler)
// instead of dereferencing a nil pool.
func (h *StopAtFirstMatchHandler[T]) Size() int {
	if h.sgPool == nil {
		return 0
	}
	return h.sgPool.Size
}
// Wait blocks until all acquired work has been released, then closes
// ResultChan and waits for the collector goroutine to exit.
//
// Because it closes ResultChan, no values may be sent on the channel once
// Wait has reached that point, and Wait must not be called more than once.
func (h *StopAtFirstMatchHandler[T]) Wait() {
	if h.poolType == Blocking {
		h.sgPool.Wait()
	} else if h.poolType == NonBlocking {
		h.wgPool.Wait()
	}

	// All producers are done: close the result channel so the collector
	// goroutine terminates, then wait for it to finish.
	close(h.ResultChan)
	h.internalWg.Wait()
}
// CombinedResults returns the distinct results collected so far as a slice.
// The slice is built from the keys of the internal result set, so its order
// is not specified.
// NOTE(review): the result set is written by the run goroutine without
// locking; presumably this is meant to be called after Wait has returned —
// confirm against callers.
func (h *StopAtFirstMatchHandler[T]) CombinedResults() []T {
return maps.Keys(h.results)
}