172 lines
4.3 KiB
Go
Raw Normal View History

package clientpool
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"time"
"github.com/pkg/errors"
"github.com/projectdiscovery/fastdialer/fastdialer"
"github.com/projectdiscovery/nuclei/v2/pkg/types"
"github.com/projectdiscovery/retryablehttp-go"
"golang.org/x/net/proxy"
)
var (
dialer *fastdialer.Dialer
poolMutex *sync.RWMutex
clientPool map[string]*retryablehttp.Client
)
func init() {
poolMutex = &sync.RWMutex{}
clientPool = make(map[string]*retryablehttp.Client)
}
// Configuration contains the custom configuration options for a client
type Configuration struct {
// Threads contains the threads for the client
Threads int
// MaxRedirects is the maximum number of redirects to follow
MaxRedirects int
// FollowRedirects specifies whether to follow redirects
FollowRedirects bool
}
// Hash returns the hash of the configuration to allow client pooling
func (c *Configuration) Hash() string {
builder := &strings.Builder{}
builder.WriteString("t")
builder.WriteString(strconv.Itoa(c.Threads))
builder.WriteString("m")
builder.WriteString(strconv.Itoa(c.MaxRedirects))
builder.WriteString("f")
builder.WriteString(strconv.FormatBool(c.FollowRedirects))
hash := builder.String()
return hash
}
// Get creates or gets a client for the protocol based on custom configuration
func Get(options *types.Options, configuration *Configuration) (*retryablehttp.Client, error) {
var proxyURL *url.URL
var err error
if dialer == nil {
dialer, err = fastdialer.NewDialer(fastdialer.DefaultOptions)
}
if err != nil {
return nil, errors.Wrap(err, "could not create dialer")
}
hash := configuration.Hash()
poolMutex.RLock()
if client, ok := clientPool[hash]; ok {
poolMutex.RUnlock()
return client, nil
}
poolMutex.RUnlock()
if options.ProxyURL != "" {
proxyURL, err = url.Parse(options.ProxyURL)
}
if err != nil {
return nil, err
}
// Multiple Host
retryablehttpOptions := retryablehttp.DefaultOptionsSpraying
disableKeepAlives := true
maxIdleConns := 0
maxConnsPerHost := 0
maxIdleConnsPerHost := -1
if configuration.Threads > 0 {
// Single host
retryablehttpOptions = retryablehttp.DefaultOptionsSingle
disableKeepAlives = false
maxIdleConnsPerHost = 500
maxConnsPerHost = 500
}
retryablehttpOptions.RetryWaitMax = 10 * time.Second
retryablehttpOptions.RetryMax = options.Retries
followRedirects := configuration.FollowRedirects
maxRedirects := configuration.MaxRedirects
transport := &http.Transport{
DialContext: dialer.Dial,
MaxIdleConns: maxIdleConns,
MaxIdleConnsPerHost: maxIdleConnsPerHost,
MaxConnsPerHost: maxConnsPerHost,
TLSClientConfig: &tls.Config{
Renegotiation: tls.RenegotiateOnceAsClient,
InsecureSkipVerify: true,
},
DisableKeepAlives: disableKeepAlives,
}
// Attempts to overwrite the dial function with the socks proxied version
if options.ProxySocksURL != "" {
var proxyAuth *proxy.Auth
socksURL, err := url.Parse(options.ProxySocksURL)
if err == nil {
proxyAuth = &proxy.Auth{}
proxyAuth.User = socksURL.User.Username()
proxyAuth.Password, _ = socksURL.User.Password()
}
dialer, err := proxy.SOCKS5("tcp", fmt.Sprintf("%s:%s", socksURL.Hostname(), socksURL.Port()), proxyAuth, proxy.Direct)
dc := dialer.(interface {
DialContext(ctx context.Context, network, addr string) (net.Conn, error)
})
if err == nil {
transport.DialContext = dc.DialContext
}
}
if proxyURL != nil {
transport.Proxy = http.ProxyURL(proxyURL)
}
client := retryablehttp.NewWithHTTPClient(&http.Client{
Transport: transport,
Timeout: time.Duration(options.Timeout) * time.Second,
CheckRedirect: makeCheckRedirectFunc(followRedirects, maxRedirects),
}, retryablehttpOptions)
poolMutex.Lock()
clientPool[hash] = client
poolMutex.Unlock()
return client, nil
}
const defaultMaxRedirects = 10
type checkRedirectFunc func(req *http.Request, via []*http.Request) error
func makeCheckRedirectFunc(followRedirects bool, maxRedirects int) checkRedirectFunc {
return func(req *http.Request, via []*http.Request) error {
if !followRedirects {
return http.ErrUseLastResponse
}
if maxRedirects == 0 {
if len(via) > defaultMaxRedirects {
return http.ErrUseLastResponse
}
return nil
}
if len(via) > maxRedirects {
return http.ErrUseLastResponse
}
return nil
}
}