From d5066ddfcab6aeaead914274cb3f59cc8b3719c5 Mon Sep 17 00:00:00 2001 From: Mzack9999 Date: Fri, 9 May 2025 00:27:59 +0200 Subject: [PATCH] fixing pg dialers --- pkg/js/libs/postgres/postgres.go | 2 +- pkg/js/utils/pgwrap/pgwrap.go | 42 +++++++++++++++++++++----------- 2 files changed, 29 insertions(+), 15 deletions(-) diff --git a/pkg/js/libs/postgres/postgres.go b/pkg/js/libs/postgres/postgres.go index 2537269fd..b617fd92b 100644 --- a/pkg/js/libs/postgres/postgres.go +++ b/pkg/js/libs/postgres/postgres.go @@ -123,7 +123,7 @@ func executeQuery(executionId string, host string, port int, username string, pa target := net.JoinHostPort(host, fmt.Sprintf("%d", port)) - connStr := fmt.Sprintf("postgres://%s:%s@%s/%s?sslmode=disable", username, password, target, dbName) + connStr := fmt.Sprintf("postgres://%s:%s@%s/%s?sslmode=disable&executionId=%s", username, password, target, dbName, executionId) db, err := sql.Open(pgwrap.PGWrapDriver, connStr) if err != nil { return nil, err diff --git a/pkg/js/utils/pgwrap/pgwrap.go b/pkg/js/utils/pgwrap/pgwrap.go index 765558ce2..bc06c02d2 100644 --- a/pkg/js/utils/pgwrap/pgwrap.go +++ b/pkg/js/utils/pgwrap/pgwrap.go @@ -6,6 +6,7 @@ import ( "database/sql/driver" "fmt" "net" + "net/url" "time" "github.com/lib/pq" @@ -17,26 +18,34 @@ const ( PGWrapDriver = "pgwrap" ) -// nolint type pgDial struct { - fd *fastdialer.Dialer + executionId string } -// nolint func (p *pgDial) Dial(network, address string) (net.Conn, error) { - return p.fd.Dial(context.TODO(), network, address) + dialers := protocolstate.GetDialersWithId(p.executionId) + if dialers == nil { + return nil, fmt.Errorf("fastdialer not initialized") + } + return dialers.Fastdialer.Dial(context.TODO(), network, address) } -// nolint func (p *pgDial) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) { + dialers := protocolstate.GetDialersWithId(p.executionId) + if dialers == nil { + return nil, fmt.Errorf("fastdialer not initialized") + } ctx, cancel := context.WithTimeoutCause(context.Background(), timeout, fastdialer.ErrDialTimeout) defer cancel() - return p.fd.Dial(ctx, network, address) + return dialers.Fastdialer.Dial(ctx, network, address) } -// nolint func (p *pgDial) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - return p.fd.Dial(ctx, network, address) + dialers := protocolstate.GetDialersWithId(p.executionId) + if dialers == nil { + return nil, fmt.Errorf("fastdialer not initialized") + } + return dialers.Fastdialer.Dial(ctx, network, address) } // Unfortunately lib/pq does not provide easy to customize or @@ -50,13 +59,18 @@ type PgDriver struct{} // Most users should only use it through database/sql package from the standard // library. func (d PgDriver) Open(name string) (driver.Conn, error) { - // Get the fastdialer instance from protocolstate - // TODO: find a way to obtain context from here - dialers := protocolstate.GetDialersWithId("") - if dialers == nil { - return nil, fmt.Errorf("fastdialer not initialized") + // Parse the connection string to get executionId + u, err := url.Parse(name) + if err != nil { + return nil, fmt.Errorf("invalid connection string: %v", err) } - return pq.DialOpen(&pgDial{fd: dialers.Fastdialer}, name) + values := u.Query() + executionId := values.Get("executionId") + // Remove executionId from the connection string + values.Del("executionId") + u.RawQuery = values.Encode() + + return pq.DialOpen(&pgDial{executionId: executionId}, u.String()) } func init() {