323 lines
9.0 KiB
Go
Raw Normal View History

package httpclientpool
import (
"context"
"crypto/tls"
"net"
"net/http"
2021-01-15 14:17:34 +05:30
"net/http/cookiejar"
"net/url"
"strconv"
"strings"
"sync"
"time"
"github.com/pkg/errors"
2021-09-03 17:25:50 +03:00
"golang.org/x/net/proxy"
"golang.org/x/net/publicsuffix"
"github.com/projectdiscovery/fastdialer/fastdialer"
"github.com/projectdiscovery/nuclei/v2/pkg/protocols/common/protocolstate"
"github.com/projectdiscovery/nuclei/v2/pkg/protocols/utils"
"github.com/projectdiscovery/nuclei/v2/pkg/types"
"github.com/projectdiscovery/rawhttp"
"github.com/projectdiscovery/retryablehttp-go"
)
var (
2021-06-26 23:49:31 +08:00
// Dialer is a copy of the fastdialer from protocolstate
2021-06-14 17:14:16 +05:30
Dialer *fastdialer.Dialer
rawHttpClient *rawhttp.Client
forceMaxRedirects int
poolMutex *sync.RWMutex
normalClient *retryablehttp.Client
clientPool map[string]*retryablehttp.Client
)
// Init initializes the clientpool implementation
func Init(options *types.Options) error {
2021-09-07 17:31:46 +03:00
// Don't create clients if already created in the past.
if normalClient != nil {
return nil
}
if options.ShouldFollowHTTPRedirects() {
forceMaxRedirects = options.MaxRedirects
}
poolMutex = &sync.RWMutex{}
clientPool = make(map[string]*retryablehttp.Client)
2020-12-29 11:42:46 +05:30
client, err := wrappedGet(options, &Configuration{})
if err != nil {
return err
}
2020-12-29 11:42:46 +05:30
normalClient = client
return nil
}
// ConnectionConfiguration contains the custom configuration options for a connection
type ConnectionConfiguration struct {
	// DisableKeepAlive of the connection
	DisableKeepAlive bool
	// Cookiejar, when non-nil, is the cookie jar attached to the client that
	// is created for this configuration.
	Cookiejar *cookiejar.Jar
}
// Configuration contains the custom configuration options for a client
type Configuration struct {
// Threads contains the threads for the client
Threads int
// MaxRedirects is the maximum number of redirects to follow
MaxRedirects int
// NoTimeout disables http request timeout for context based usage
NoTimeout bool
2021-02-26 13:13:11 +05:30
// CookieReuse enables cookie reuse for the http client (cookiejar impl)
CookieReuse bool
// FollowRedirects specifies the redirects flow
RedirectFlow RedirectFlow
2021-08-08 21:52:01 +02:00
// Connection defines custom connection configuration
Connection *ConnectionConfiguration
}
// Hash returns the hash of the configuration to allow client pooling
func (c *Configuration) Hash() string {
builder := &strings.Builder{}
builder.Grow(16)
builder.WriteString("t")
builder.WriteString(strconv.Itoa(c.Threads))
builder.WriteString("m")
builder.WriteString(strconv.Itoa(c.MaxRedirects))
builder.WriteString("n")
builder.WriteString(strconv.FormatBool(c.NoTimeout))
builder.WriteString("f")
builder.WriteString(strconv.Itoa(int(c.RedirectFlow)))
2021-01-15 14:17:34 +05:30
builder.WriteString("r")
builder.WriteString(strconv.FormatBool(c.CookieReuse))
2021-08-08 21:52:01 +02:00
builder.WriteString("c")
builder.WriteString(strconv.FormatBool(c.Connection != nil))
hash := builder.String()
return hash
}
2021-09-03 17:25:50 +03:00
// HasStandardOptions checks whether the configuration requires custom settings
2021-08-08 21:52:01 +02:00
func (c *Configuration) HasStandardOptions() bool {
return c.Threads == 0 && c.MaxRedirects == 0 && c.RedirectFlow == DontFollowRedirect && !c.CookieReuse && c.Connection == nil && !c.NoTimeout
2021-08-08 21:52:01 +02:00
}
// GetRawHTTP returns the rawhttp request client
2021-06-25 08:16:54 +02:00
func GetRawHTTP(options *types.Options) *rawhttp.Client {
if rawHttpClient == nil {
rawHttpOptions := rawhttp.DefaultOptions
2022-03-10 13:49:17 +08:00
if types.ProxyURL != "" {
rawHttpOptions.Proxy = types.ProxyURL
} else if types.ProxySocksURL != "" {
rawHttpOptions.Proxy = types.ProxySocksURL
} else if Dialer != nil {
rawHttpOptions.FastDialer = Dialer
2022-03-10 13:49:17 +08:00
}
rawHttpOptions.Timeout = time.Duration(options.Timeout) * time.Second
rawHttpClient = rawhttp.NewClient(rawHttpOptions)
}
return rawHttpClient
}
// Get creates or gets a client for the protocol based on custom configuration
func Get(options *types.Options, configuration *Configuration) (*retryablehttp.Client, error) {
2021-08-08 21:52:01 +02:00
if configuration.HasStandardOptions() {
return normalClient, nil
}
2020-12-29 11:42:46 +05:30
return wrappedGet(options, configuration)
}
2021-06-26 23:49:31 +08:00
// wrappedGet wraps a get operation without normal client check
2020-12-29 11:42:46 +05:30
func wrappedGet(options *types.Options, configuration *Configuration) (*retryablehttp.Client, error) {
var err error
if Dialer == nil {
Dialer = protocolstate.Dialer
}
hash := configuration.Hash()
poolMutex.RLock()
if client, ok := clientPool[hash]; ok {
poolMutex.RUnlock()
return client, nil
}
poolMutex.RUnlock()
// Multiple Host
retryableHttpOptions := retryablehttp.DefaultOptionsSpraying
disableKeepAlives := true
maxIdleConns := 0
maxConnsPerHost := 0
maxIdleConnsPerHost := -1
if configuration.Threads > 0 {
// Single host
retryableHttpOptions = retryablehttp.DefaultOptionsSingle
disableKeepAlives = false
maxIdleConnsPerHost = 500
maxConnsPerHost = 500
}
retryableHttpOptions.RetryWaitMax = 10 * time.Second
retryableHttpOptions.RetryMax = options.Retries
redirectFlow := configuration.RedirectFlow
maxRedirects := configuration.MaxRedirects
if forceMaxRedirects > 0 {
// by default we enable general redirects following
switch {
case options.FollowHostRedirects:
redirectFlow = FollowSameHostRedirect
default:
redirectFlow = FollowAllRedirect
}
maxRedirects = forceMaxRedirects
}
2022-04-27 11:19:44 -05:00
if options.DisableRedirects {
options.FollowRedirects = false
options.FollowHostRedirects = false
redirectFlow = DontFollowRedirect
2022-04-27 11:19:44 -05:00
maxRedirects = 0
}
2021-08-08 21:52:01 +02:00
// override connection's settings if required
if configuration.Connection != nil {
disableKeepAlives = configuration.Connection.DisableKeepAlive
}
// Set the base TLS configuration definition
tlsConfig := &tls.Config{
Renegotiation: tls.RenegotiateOnceAsClient,
InsecureSkipVerify: true,
MinVersion: tls.VersionTLS10,
}
if options.SNI != "" {
tlsConfig.ServerName = options.SNI
}
2021-10-27 12:11:42 -04:00
// Add the client certificate authentication to the request if it's configured
tlsConfig, err = utils.AddConfiguredClientCertToRequest(tlsConfig, options)
if err != nil {
return nil, errors.Wrap(err, "could not create client certificate")
}
transport := &http.Transport{
DialContext: Dialer.Dial,
DialTLSContext: Dialer.DialTLS,
MaxIdleConns: maxIdleConns,
MaxIdleConnsPerHost: maxIdleConnsPerHost,
MaxConnsPerHost: maxConnsPerHost,
TLSClientConfig: tlsConfig,
DisableKeepAlives: disableKeepAlives,
}
if types.ProxyURL != "" {
if proxyURL, err := url.Parse(types.ProxyURL); err == nil {
transport.Proxy = http.ProxyURL(proxyURL)
}
} else if types.ProxySocksURL != "" {
socksURL, proxyErr := url.Parse(types.ProxySocksURL)
if proxyErr != nil {
return nil, proxyErr
}
dialer, err := proxy.FromURL(socksURL, proxy.Direct)
if err != nil {
return nil, err
}
dc := dialer.(interface {
DialContext(ctx context.Context, network, addr string) (net.Conn, error)
})
if proxyErr == nil {
transport.DialContext = dc.DialContext
transport.DialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
// upgrade proxy connection to tls
conn, err := dc.DialContext(ctx, network, addr)
if err != nil {
return nil, err
}
return tls.Client(conn, tlsConfig), nil
}
}
}
2021-01-15 14:17:34 +05:30
var jar *cookiejar.Jar
if configuration.Connection != nil && configuration.Connection.Cookiejar != nil {
jar = configuration.Connection.Cookiejar
} else if configuration.CookieReuse {
2021-01-15 14:17:34 +05:30
if jar, err = cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}); err != nil {
return nil, errors.Wrap(err, "could not create cookiejar")
}
}
httpclient := &http.Client{
Transport: transport,
CheckRedirect: makeCheckRedirectFunc(redirectFlow, maxRedirects),
}
if !configuration.NoTimeout {
httpclient.Timeout = time.Duration(options.Timeout) * time.Second
}
client := retryablehttp.NewWithHTTPClient(httpclient, retryableHttpOptions)
2021-01-15 14:17:34 +05:30
if jar != nil {
client.HTTPClient.Jar = jar
}
client.CheckRetry = retryablehttp.HostSprayRetryPolicy()
2021-01-15 14:17:34 +05:30
// Only add to client pool if we don't have a cookie jar in place.
if jar == nil {
poolMutex.Lock()
clientPool[hash] = client
poolMutex.Unlock()
}
return client, nil
}
type (
	// RedirectFlow selects how the client reacts to HTTP redirects.
	RedirectFlow uint8

	// checkRedirectFunc has the signature of the CheckRedirect hook of
	// http.Client.
	checkRedirectFunc func(*http.Request, []*http.Request) error
)

// Redirect flows understood by makeCheckRedirectFunc.
const (
	// DontFollowRedirect stops at the first response instead of following
	// the redirect.
	DontFollowRedirect RedirectFlow = iota
	// FollowSameHostRedirect follows a redirect only while the target host
	// equals the host of the first request in the chain.
	FollowSameHostRedirect
	// FollowAllRedirect follows redirects to any host.
	FollowAllRedirect
)

// defaultMaxRedirects is the redirect limit applied when the configured
// maximum is zero.
const defaultMaxRedirects = 10
// makeCheckRedirectFunc builds the redirect hook for the given flow.
//
// The returned function refuses every redirect for DontFollowRedirect. For
// FollowSameHostRedirect it refuses redirects whose target host differs from
// the host of the first request in the chain, and otherwise applies the
// redirect limit. For FollowAllRedirect it only applies the redirect limit.
// Any other flow value lets the redirect through.
func makeCheckRedirectFunc(redirectType RedirectFlow, maxRedirects int) checkRedirectFunc {
	return func(req *http.Request, via []*http.Request) error {
		if redirectType == DontFollowRedirect {
			return http.ErrUseLastResponse
		}

		if redirectType == FollowSameHostRedirect {
			targetHost := req.URL.Host
			// Prefer the Host field of the first request and fall back to
			// the host of its URL when that field is empty.
			originHost := via[0].Host
			if originHost == "" {
				originHost = via[0].URL.Host
			}
			if targetHost != originHost {
				// Tell the http client to not follow redirect
				return http.ErrUseLastResponse
			}
		}

		if redirectType == FollowSameHostRedirect || redirectType == FollowAllRedirect {
			return checkMaxRedirects(req, via, maxRedirects)
		}
		return nil
	}
}
// checkMaxRedirects enforces the redirect limit for a redirect chain.
//
// via holds the requests made so far. When maxRedirects is zero the limit is
// defaultMaxRedirects. http.ErrUseLastResponse is returned once the chain is
// longer than the limit; otherwise the result is nil.
func checkMaxRedirects(req *http.Request, via []*http.Request, maxRedirects int) error {
	limit := maxRedirects
	if limit == 0 {
		limit = defaultMaxRedirects
	}
	if len(via) > limit {
		return http.ErrUseLastResponse
	}
	return nil
}