diff --git a/v2/pkg/protocols/http/clientpool/clientpool.go b/v2/pkg/protocols/http/clientpool/clientpool.go new file mode 100644 index 000000000..747b55371 --- /dev/null +++ b/v2/pkg/protocols/http/clientpool/clientpool.go @@ -0,0 +1,171 @@ +package clientpool + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "net/http" + "net/url" + "strconv" + "strings" + "sync" + "time" + + "github.com/pkg/errors" + "github.com/projectdiscovery/fastdialer/fastdialer" + "github.com/projectdiscovery/nuclei/v2/pkg/types" + "github.com/projectdiscovery/retryablehttp-go" + "golang.org/x/net/proxy" +) + +var ( + dialer *fastdialer.Dialer + poolMutex *sync.RWMutex + clientPool map[string]*retryablehttp.Client +) + +func init() { + poolMutex = &sync.RWMutex{} + clientPool = make(map[string]*retryablehttp.Client) +} + +// Configuration contains the custom configuration options for a client +type Configuration struct { + // Threads contains the threads for the client + Threads int + // MaxRedirects is the maximum number of redirects to follow + MaxRedirects int + // FollowRedirects specifies whether to follow redirects + FollowRedirects bool +} + +// Hash returns the hash of the configuration to allow client pooling +func (c *Configuration) Hash() string { + builder := &strings.Builder{} + builder.WriteString("t") + builder.WriteString(strconv.Itoa(c.Threads)) + builder.WriteString("m") + builder.WriteString(strconv.Itoa(c.MaxRedirects)) + builder.WriteString("f") + builder.WriteString(strconv.FormatBool(c.FollowRedirects)) + hash := builder.String() + return hash +} + +// Get creates or gets a client for the protocol based on custom configuration +func Get(options *types.Options, configuration *Configuration) (*retryablehttp.Client, error) { + var proxyURL *url.URL + var err error + + if dialer == nil { + dialer, err = fastdialer.NewDialer(fastdialer.DefaultOptions) + } + if err != nil { + return nil, errors.Wrap(err, "could not create dialer") + } + + hash := configuration.Hash() + poolMutex.RLock() + if client, ok := clientPool[hash]; ok { + poolMutex.RUnlock() + return client, nil + } + poolMutex.RUnlock() + + if options.ProxyURL != "" { + proxyURL, err = url.Parse(options.ProxyURL) + } + if err != nil { + return nil, err + } + + // Multiple Host + retryablehttpOptions := retryablehttp.DefaultOptionsSpraying + disableKeepAlives := true + maxIdleConns := 0 + maxConnsPerHost := 0 + maxIdleConnsPerHost := -1 + + if configuration.Threads > 0 { + // Single host + retryablehttpOptions = retryablehttp.DefaultOptionsSingle + disableKeepAlives = false + maxIdleConnsPerHost = 500 + maxConnsPerHost = 500 + } + + retryablehttpOptions.RetryWaitMax = 10 * time.Second + retryablehttpOptions.RetryMax = options.Retries + followRedirects := configuration.FollowRedirects + maxRedirects := configuration.MaxRedirects + + transport := &http.Transport{ + DialContext: dialer.Dial, + MaxIdleConns: maxIdleConns, + MaxIdleConnsPerHost: maxIdleConnsPerHost, + MaxConnsPerHost: maxConnsPerHost, + TLSClientConfig: &tls.Config{ + Renegotiation: tls.RenegotiateOnceAsClient, + InsecureSkipVerify: true, + }, + DisableKeepAlives: disableKeepAlives, + } + + // Attempts to overwrite the dial function with the socks proxied version + if options.ProxySocksURL != "" { + var proxyAuth *proxy.Auth + + socksURL, err := url.Parse(options.ProxySocksURL) + if err == nil { + proxyAuth = &proxy.Auth{} + proxyAuth.User = socksURL.User.Username() + proxyAuth.Password, _ = socksURL.User.Password() + } + dialer, err := proxy.SOCKS5("tcp", fmt.Sprintf("%s:%s", socksURL.Hostname(), socksURL.Port()), proxyAuth, proxy.Direct) + dc := dialer.(interface { + DialContext(ctx context.Context, network, addr string) (net.Conn, error) + }) + if err == nil { + transport.DialContext = dc.DialContext + } + } + if proxyURL != nil { + transport.Proxy = http.ProxyURL(proxyURL) + } + + client := retryablehttp.NewWithHTTPClient(&http.Client{ + Transport: transport, + Timeout: time.Duration(options.Timeout) * time.Second, + CheckRedirect: makeCheckRedirectFunc(followRedirects, maxRedirects), + }, retryablehttpOptions) + + poolMutex.Lock() + clientPool[hash] = client + poolMutex.Unlock() + return client, nil +} + +const defaultMaxRedirects = 10 + +type checkRedirectFunc func(req *http.Request, via []*http.Request) error + +func makeCheckRedirectFunc(followRedirects bool, maxRedirects int) checkRedirectFunc { + return func(req *http.Request, via []*http.Request) error { + if !followRedirects { + return http.ErrUseLastResponse + } + + if maxRedirects == 0 { + if len(via) > defaultMaxRedirects { + return http.ErrUseLastResponse + } + return nil + } + + if len(via) > maxRedirects { + return http.ErrUseLastResponse + } + return nil + } +}