mirror of
https://github.com/projectdiscovery/nuclei.git
synced 2025-12-24 19:45:41 +00:00
fixing pg dialers
This commit is contained in:
parent
781ca16eeb
commit
d5066ddfca
@ -123,7 +123,7 @@ func executeQuery(executionId string, host string, port int, username string, pa
|
||||
|
||||
target := net.JoinHostPort(host, fmt.Sprintf("%d", port))
|
||||
|
||||
connStr := fmt.Sprintf("postgres://%s:%s@%s/%s?sslmode=disable", username, password, target, dbName)
|
||||
connStr := fmt.Sprintf("postgres://%s:%s@%s/%s?sslmode=disable&executionId=%s", username, password, target, dbName, executionId)
|
||||
db, err := sql.Open(pgwrap.PGWrapDriver, connStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@ -6,6 +6,7 @@ import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
@ -17,26 +18,34 @@ const (
|
||||
PGWrapDriver = "pgwrap"
|
||||
)
|
||||
|
||||
// nolint
|
||||
type pgDial struct {
|
||||
fd *fastdialer.Dialer
|
||||
executionId string
|
||||
}
|
||||
|
||||
// nolint
|
||||
func (p *pgDial) Dial(network, address string) (net.Conn, error) {
|
||||
return p.fd.Dial(context.TODO(), network, address)
|
||||
dialers := protocolstate.GetDialersWithId(p.executionId)
|
||||
if dialers == nil {
|
||||
return nil, fmt.Errorf("fastdialer not initialized")
|
||||
}
|
||||
return dialers.Fastdialer.Dial(context.TODO(), network, address)
|
||||
}
|
||||
|
||||
// nolint
|
||||
func (p *pgDial) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
|
||||
dialers := protocolstate.GetDialersWithId(p.executionId)
|
||||
if dialers == nil {
|
||||
return nil, fmt.Errorf("fastdialer not initialized")
|
||||
}
|
||||
ctx, cancel := context.WithTimeoutCause(context.Background(), timeout, fastdialer.ErrDialTimeout)
|
||||
defer cancel()
|
||||
return p.fd.Dial(ctx, network, address)
|
||||
return dialers.Fastdialer.Dial(ctx, network, address)
|
||||
}
|
||||
|
||||
// nolint
|
||||
func (p *pgDial) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return p.fd.Dial(ctx, network, address)
|
||||
dialers := protocolstate.GetDialersWithId(p.executionId)
|
||||
if dialers == nil {
|
||||
return nil, fmt.Errorf("fastdialer not initialized")
|
||||
}
|
||||
return dialers.Fastdialer.Dial(ctx, network, address)
|
||||
}
|
||||
|
||||
// Unfortunately lib/pq does not provide easy to customize or
|
||||
@ -50,13 +59,18 @@ type PgDriver struct{}
|
||||
// Most users should only use it through database/sql package from the standard
|
||||
// library.
|
||||
func (d PgDriver) Open(name string) (driver.Conn, error) {
|
||||
// Get the fastdialer instance from protocolstate
|
||||
// TODO: find a way to obtain context from here
|
||||
dialers := protocolstate.GetDialersWithId("")
|
||||
if dialers == nil {
|
||||
return nil, fmt.Errorf("fastdialer not initialized")
|
||||
// Parse the connection string to get executionId
|
||||
u, err := url.Parse(name)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid connection string: %v", err)
|
||||
}
|
||||
return pq.DialOpen(&pgDial{fd: dialers.Fastdialer}, name)
|
||||
values := u.Query()
|
||||
executionId := values.Get("executionId")
|
||||
// Remove executionId from the connection string
|
||||
values.Del("executionId")
|
||||
u.RawQuery = values.Encode()
|
||||
|
||||
return pq.DialOpen(&pgDial{executionId: executionId}, u.String())
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user