2024-05-25 00:29:04 +05:30
|
|
|
package pgwrap
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
"context"
|
|
|
|
|
"database/sql"
|
|
|
|
|
"database/sql/driver"
|
2025-07-09 14:47:26 -05:00
|
|
|
"fmt"
|
2024-05-25 00:29:04 +05:30
|
|
|
"net"
|
2025-07-09 14:47:26 -05:00
|
|
|
"net/url"
|
2024-05-25 00:29:04 +05:30
|
|
|
"time"
|
|
|
|
|
|
|
|
|
|
"github.com/lib/pq"
|
|
|
|
|
"github.com/projectdiscovery/fastdialer/fastdialer"
|
|
|
|
|
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
// PGWrapDriver is the driver name under which this package registers itself
// with database/sql.
const PGWrapDriver = "pgwrap"
|
|
|
|
|
|
|
|
|
|
// pgDial is the dialer passed to lib/pq in place of its default one. Every
// dial is routed through the dialers that protocolstate holds for the
// execution the connection belongs to.
type pgDial struct {
	// executionId is the key used to look up the dialers in protocolstate.
	executionId string
}
|
|
|
|
|
|
|
|
|
|
func (p *pgDial) Dial(network, address string) (net.Conn, error) {
|
2025-07-09 14:47:26 -05:00
|
|
|
dialers := protocolstate.GetDialersWithId(p.executionId)
|
|
|
|
|
if dialers == nil {
|
|
|
|
|
return nil, fmt.Errorf("dialers not initialized for %s", p.executionId)
|
|
|
|
|
}
|
|
|
|
|
return dialers.Fastdialer.Dial(context.TODO(), network, address)
|
2024-05-25 00:29:04 +05:30
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (p *pgDial) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
|
2025-07-09 14:47:26 -05:00
|
|
|
dialers := protocolstate.GetDialersWithId(p.executionId)
|
|
|
|
|
if dialers == nil {
|
|
|
|
|
return nil, fmt.Errorf("dialers not initialized for %s", p.executionId)
|
|
|
|
|
}
|
2024-05-25 00:29:04 +05:30
|
|
|
ctx, cancel := context.WithTimeoutCause(context.Background(), timeout, fastdialer.ErrDialTimeout)
|
|
|
|
|
defer cancel()
|
2025-07-09 14:47:26 -05:00
|
|
|
return dialers.Fastdialer.Dial(ctx, network, address)
|
2024-05-25 00:29:04 +05:30
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (p *pgDial) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
2025-07-09 14:47:26 -05:00
|
|
|
dialers := protocolstate.GetDialersWithId(p.executionId)
|
|
|
|
|
if dialers == nil {
|
|
|
|
|
return nil, fmt.Errorf("dialers not initialized for %s", p.executionId)
|
|
|
|
|
}
|
|
|
|
|
return dialers.Fastdialer.Dial(ctx, network, address)
|
2024-05-25 00:29:04 +05:30
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Unfortunately, lib/pq does not provide an easy way to customize or
// replace its dialer, so we need to hijack it by wrapping it in our own
// driver and registering that as a postgres driver.
|
|
|
|
|
|
|
|
|
|
// PgDriver is a database/sql driver for Postgres. It carries no state of its
// own; each connection is opened through its Open method.
type PgDriver struct{}
|
|
|
|
|
|
|
|
|
|
// Open opens a new connection to the database. name is a connection string.
|
|
|
|
|
// Most users should only use it through database/sql package from the standard
|
|
|
|
|
// library.
|
|
|
|
|
func (d PgDriver) Open(name string) (driver.Conn, error) {
|
2025-07-09 14:47:26 -05:00
|
|
|
// Parse the connection string to get executionId
|
|
|
|
|
u, err := url.Parse(name)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, fmt.Errorf("invalid connection string: %v", err)
|
|
|
|
|
}
|
|
|
|
|
values := u.Query()
|
|
|
|
|
executionId := values.Get("executionId")
|
|
|
|
|
// Remove executionId from the connection string
|
|
|
|
|
values.Del("executionId")
|
|
|
|
|
u.RawQuery = values.Encode()
|
|
|
|
|
|
|
|
|
|
return pq.DialOpen(&pgDial{executionId: executionId}, u.String())
|
2024-05-25 00:29:04 +05:30
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func init() {
|
|
|
|
|
sql.Register(PGWrapDriver, &PgDriver{})
|
|
|
|
|
}
|