fixing pg dialers

This commit is contained in:
Mzack9999 2025-05-09 00:27:59 +02:00
parent 781ca16eeb
commit d5066ddfca
2 changed files with 29 additions and 15 deletions

View File

@ -123,7 +123,7 @@ func executeQuery(executionId string, host string, port int, username string, pa
target := net.JoinHostPort(host, fmt.Sprintf("%d", port))
connStr := fmt.Sprintf("postgres://%s:%s@%s/%s?sslmode=disable", username, password, target, dbName)
connStr := fmt.Sprintf("postgres://%s:%s@%s/%s?sslmode=disable&executionId=%s", username, password, target, dbName, executionId)
db, err := sql.Open(pgwrap.PGWrapDriver, connStr)
if err != nil {
return nil, err

View File

@ -6,6 +6,7 @@ import (
"database/sql/driver"
"fmt"
"net"
"net/url"
"time"
"github.com/lib/pq"
@ -17,26 +18,34 @@ const (
PGWrapDriver = "pgwrap"
)
// nolint
type pgDial struct {
fd *fastdialer.Dialer
executionId string
}
// nolint
func (p *pgDial) Dial(network, address string) (net.Conn, error) {
return p.fd.Dial(context.TODO(), network, address)
dialers := protocolstate.GetDialersWithId(p.executionId)
if dialers == nil {
return nil, fmt.Errorf("fastdialer not initialized")
}
return dialers.Fastdialer.Dial(context.TODO(), network, address)
}
// nolint
func (p *pgDial) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
dialers := protocolstate.GetDialersWithId(p.executionId)
if dialers == nil {
return nil, fmt.Errorf("fastdialer not initialized")
}
ctx, cancel := context.WithTimeoutCause(context.Background(), timeout, fastdialer.ErrDialTimeout)
defer cancel()
return p.fd.Dial(ctx, network, address)
return dialers.Fastdialer.Dial(ctx, network, address)
}
// nolint
func (p *pgDial) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
return p.fd.Dial(ctx, network, address)
dialers := protocolstate.GetDialersWithId(p.executionId)
if dialers == nil {
return nil, fmt.Errorf("fastdialer not initialized")
}
return dialers.Fastdialer.Dial(ctx, network, address)
}
// Unfortunately lib/pq does not provide easy to customize or
@ -50,13 +59,18 @@ type PgDriver struct{}
// Most users should only use it through database/sql package from the standard
// library.
func (d PgDriver) Open(name string) (driver.Conn, error) {
// Get the fastdialer instance from protocolstate
// TODO: find a way to obtain context from here
dialers := protocolstate.GetDialersWithId("")
if dialers == nil {
return nil, fmt.Errorf("fastdialer not initialized")
// Parse the connection string to get executionId
u, err := url.Parse(name)
if err != nil {
return nil, fmt.Errorf("invalid connection string: %v", err)
}
return pq.DialOpen(&pgDial{fd: dialers.Fastdialer}, name)
values := u.Query()
executionId := values.Get("executionId")
// Remove executionId from the connection string
values.Del("executionId")
u.RawQuery = values.Encode()
return pq.DialOpen(&pgDial{executionId: executionId}, u.String())
}
func init() {