diff --git a/ee/query-service/app/server.go b/ee/query-service/app/server.go index aebd30b3b64e..fc516f49e95a 100644 --- a/ee/query-service/app/server.go +++ b/ee/query-service/app/server.go @@ -257,6 +257,7 @@ func (s *Server) createPublicServer(apiHandler *api.APIHandler, web web.Web) (*h s.config.APIServer.Timeout.Max, ).Wrap) r.Use(middleware.NewLogging(s.signoz.Instrumentation.Logger(), s.config.APIServer.Logging.ExcludedRoutes).Wrap) + r.Use(middleware.NewComment().Wrap) apiHandler.RegisterRoutes(r, am) apiHandler.RegisterLogsRoutes(r, am) diff --git a/pkg/http/middleware/api_key.go b/pkg/http/middleware/api_key.go index 0d53b736bcae..22d088178be4 100644 --- a/pkg/http/middleware/api_key.go +++ b/pkg/http/middleware/api_key.go @@ -9,6 +9,7 @@ import ( "github.com/SigNoz/signoz/pkg/sqlstore" "github.com/SigNoz/signoz/pkg/types" "github.com/SigNoz/signoz/pkg/types/authtypes" + "github.com/SigNoz/signoz/pkg/types/ctxtypes" "github.com/SigNoz/signoz/pkg/valuer" ) @@ -97,7 +98,12 @@ func (a *APIKey) Wrap(next http.Handler) http.Handler { return } - r = r.WithContext(ctx) + comment := ctxtypes.CommentFromContext(ctx) + comment.Set("auth_type", "api_key") + comment.Set("user_id", claims.UserID) + comment.Set("org_id", claims.OrgID) + + r = r.WithContext(ctxtypes.NewContextWithComment(ctx, comment)) next.ServeHTTP(w, r) diff --git a/pkg/http/middleware/auth.go b/pkg/http/middleware/auth.go index 8e6a4e3a03e3..d52b33df225d 100644 --- a/pkg/http/middleware/auth.go +++ b/pkg/http/middleware/auth.go @@ -7,6 +7,7 @@ import ( "github.com/SigNoz/signoz/pkg/sharder" "github.com/SigNoz/signoz/pkg/types" "github.com/SigNoz/signoz/pkg/types/authtypes" + "github.com/SigNoz/signoz/pkg/types/ctxtypes" "github.com/SigNoz/signoz/pkg/valuer" ) @@ -50,7 +51,12 @@ func (a *Auth) Wrap(next http.Handler) http.Handler { return } - r = r.WithContext(ctx) + comment := ctxtypes.CommentFromContext(ctx) + comment.Set("auth_type", "jwt") + comment.Set("user_id", claims.UserID) + comment.Set("org_id", claims.OrgID) + + r = r.WithContext(ctxtypes.NewContextWithComment(ctx, comment)) next.ServeHTTP(w, r) }) diff --git a/pkg/http/middleware/comment.go b/pkg/http/middleware/comment.go new file mode 100644 index 000000000000..e6123a933141 --- /dev/null +++ b/pkg/http/middleware/comment.go @@ -0,0 +1,24 @@ +package middleware + +import ( + "net/http" + + "github.com/SigNoz/signoz/pkg/types/ctxtypes" +) + +type Comment struct{} + +func NewComment() *Comment { + return &Comment{} +} + +func (middleware *Comment) Wrap(next http.Handler) http.Handler { + return http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + + comment := ctxtypes.CommentFromContext(req.Context()) + comment.Merge(ctxtypes.CommentFromHTTPRequest(req)) + + req = req.WithContext(ctxtypes.NewContextWithComment(req.Context(), comment)) + next.ServeHTTP(rw, req) + }) +} diff --git a/pkg/http/middleware/logging.go b/pkg/http/middleware/logging.go index ba3d805758b3..8a967a5c729d 100644 --- a/pkg/http/middleware/logging.go +++ b/pkg/http/middleware/logging.go @@ -2,16 +2,11 @@ package middleware import ( "bytes" - "context" "log/slog" "net" "net/http" - "net/url" - "strings" "time" - "github.com/SigNoz/signoz/pkg/query-service/common" - "github.com/SigNoz/signoz/pkg/types/authtypes" "github.com/gorilla/mux" semconv "go.opentelemetry.io/otel/semconv/v1.26.0" ) @@ -55,9 +50,6 @@ func (middleware *Logging) Wrap(next http.Handler) http.Handler { string(semconv.HTTPRouteKey), path, } - logCommentKVs := middleware.getLogCommentKVs(req) - req = req.WithContext(context.WithValue(req.Context(), common.LogCommentKey, logCommentKVs)) - badResponseBuffer := new(bytes.Buffer) writer := newBadResponseLoggingWriter(rw, badResponseBuffer) next.ServeHTTP(writer, req) @@ -85,67 +77,3 @@ func (middleware *Logging) Wrap(next http.Handler) http.Handler { } }) } - -func (middleware *Logging) getLogCommentKVs(r *http.Request) map[string]string { - referrer := r.Header.Get("Referer") - - var path, dashboardID, alertID, page, client, viewName, tab string - - if referrer != "" { - referrerURL, _ := url.Parse(referrer) - client = "browser" - path = referrerURL.Path - - if strings.Contains(path, "/dashboard") { - // Split the path into segments - pathSegments := strings.Split(referrerURL.Path, "/") - // The dashboard ID should be the segment after "/dashboard/" - // Loop through pathSegments to find "dashboard" and then take the next segment as the ID - for i, segment := range pathSegments { - if segment == "dashboard" && i < len(pathSegments)-1 { - // Return the next segment, which should be the dashboard ID - dashboardID = pathSegments[i+1] - } - } - page = "dashboards" - } else if strings.Contains(path, "/alerts") { - urlParams := referrerURL.Query() - alertID = urlParams.Get("ruleId") - page = "alerts" - } else if strings.Contains(path, "logs") && strings.Contains(path, "explorer") { - page = "logs-explorer" - viewName = referrerURL.Query().Get("viewName") - } else if strings.Contains(path, "/trace") || strings.Contains(path, "traces-explorer") { - page = "traces-explorer" - viewName = referrerURL.Query().Get("viewName") - } else if strings.Contains(path, "/services") { - page = "services" - tab = referrerURL.Query().Get("tab") - if tab == "" { - tab = "OVER_METRICS" - } - } else if strings.Contains(path, "/metrics") { - page = "metrics-explorer" - } - } else { - client = "api" - } - - var email string - claims, err := authtypes.ClaimsFromContext(r.Context()) - if err == nil { - email = claims.Email - } - - kvs := map[string]string{ - "path": path, - "dashboardID": dashboardID, - "alertID": alertID, - "source": page, - "client": client, - "viewName": viewName, - "servicesTab": tab, - "email": email, - } - return kvs -} diff --git a/pkg/querier/api.go b/pkg/querier/api.go index ba49b7dbe392..e79fc388d410 100644 --- a/pkg/querier/api.go +++ b/pkg/querier/api.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "net/http" - "regexp" "runtime/debug" "github.com/SigNoz/signoz/pkg/analytics" @@ -12,6 +11,7 @@ import ( "github.com/SigNoz/signoz/pkg/factory" "github.com/SigNoz/signoz/pkg/http/render" "github.com/SigNoz/signoz/pkg/types/authtypes" + "github.com/SigNoz/signoz/pkg/types/ctxtypes" qbtypes "github.com/SigNoz/signoz/pkg/types/querybuildertypes/querybuildertypesv5" "github.com/SigNoz/signoz/pkg/valuer" "github.com/SigNoz/signoz/pkg/variables" @@ -166,49 +166,9 @@ func (a *API) logEvent(ctx context.Context, referrer string, event *qbtypes.QBEv return } - properties["referrer"] = referrer - - logsExplorerMatched, _ := regexp.MatchString(`/logs/logs-explorer(?:\?.*)?$`, referrer) - traceExplorerMatched, _ := regexp.MatchString(`/traces-explorer(?:\?.*)?$`, referrer) - metricsExplorerMatched, _ := regexp.MatchString(`/metrics-explorer/explorer(?:\?.*)?$`, referrer) - dashboardMatched, _ := regexp.MatchString(`/dashboard/[a-zA-Z0-9\-]+/(new|edit)(?:\?.*)?$`, referrer) - alertMatched, _ := regexp.MatchString(`/alerts/(new|edit)(?:\?.*)?$`, referrer) - - switch { - case dashboardMatched: - properties["module_name"] = "dashboard" - case alertMatched: - properties["module_name"] = "rule" - case metricsExplorerMatched: - properties["module_name"] = "metrics-explorer" - case logsExplorerMatched: - properties["module_name"] = "logs-explorer" - case traceExplorerMatched: - properties["module_name"] = "traces-explorer" - default: - return - } - - if dashboardMatched { - if dashboardIDRegex, err := regexp.Compile(`/dashboard/([a-f0-9\-]+)/`); err == nil { - if matches := dashboardIDRegex.FindStringSubmatch(referrer); len(matches) > 1 { - properties["dashboard_id"] = matches[1] - } - } - - if widgetIDRegex, err := regexp.Compile(`widgetId=([a-f0-9\-]+)`); err == nil { - if matches := widgetIDRegex.FindStringSubmatch(referrer); len(matches) > 1 { - properties["widget_id"] = matches[1] - } - } - } - - if alertMatched { - if alertIDRegex, err := regexp.Compile(`ruleId=(\d+)`); err == nil { - if matches := alertIDRegex.FindStringSubmatch(referrer); len(matches) > 1 { - properties["rule_id"] = matches[1] - } - } + comments := ctxtypes.CommentFromContext(ctx).Map() + for key, value := range comments { + properties[key] = value } if !event.HasData { diff --git a/pkg/query-service/app/clickhouseReader/reader.go b/pkg/query-service/app/clickhouseReader/reader.go index 84426b538212..098662127930 100644 --- a/pkg/query-service/app/clickhouseReader/reader.go +++ b/pkg/query-service/app/clickhouseReader/reader.go @@ -3640,28 +3640,8 @@ func readRowsForTimeSeriesResult(rows driver.Rows, vars []interface{}, columnNam return seriesList, getPersonalisedError(rows.Err()) } -func logCommentKVs(ctx context.Context) map[string]string { - kv := ctx.Value(common.LogCommentKey) - if kv == nil { - return nil - } - logCommentKVs, ok := kv.(map[string]string) - if !ok { - return nil - } - return logCommentKVs -} - // GetTimeSeriesResultV3 runs the query and returns list of time series func (r *ClickHouseReader) GetTimeSeriesResultV3(ctx context.Context, query string) ([]*v3.Series, error) { - - ctxArgs := map[string]interface{}{"query": query} - for k, v := range logCommentKVs(ctx) { - ctxArgs[k] = v - } - - defer utils.Elapsed("GetTimeSeriesResultV3", ctxArgs)() - // Hook up query progress reporting if requested. queryId := ctx.Value("queryId") if queryId != nil { @@ -3725,20 +3705,12 @@ func (r *ClickHouseReader) GetTimeSeriesResultV3(ctx context.Context, query stri // GetListResultV3 runs the query and returns list of rows func (r *ClickHouseReader) GetListResultV3(ctx context.Context, query string) ([]*v3.Row, error) { - - ctxArgs := map[string]interface{}{"query": query} - for k, v := range logCommentKVs(ctx) { - ctxArgs[k] = v - } - - defer utils.Elapsed("GetListResultV3", ctxArgs)() - rows, err := r.db.Query(ctx, query) - if err != nil { zap.L().Error("error while reading time series result", zap.Error(err)) return nil, errors.New(err.Error()) } + defer rows.Close() var ( diff --git a/pkg/query-service/app/server.go b/pkg/query-service/app/server.go index fd67fd61ca46..b33073eb74fc 100644 --- a/pkg/query-service/app/server.go +++ b/pkg/query-service/app/server.go @@ -220,6 +220,7 @@ func (s *Server) createPublicServer(api *APIHandler, web web.Web) (*http.Server, ).Wrap) r.Use(middleware.NewAPIKey(s.signoz.SQLStore, []string{"SIGNOZ-API-KEY"}, s.signoz.Instrumentation.Logger(), s.signoz.Sharder).Wrap) r.Use(middleware.NewLogging(s.signoz.Instrumentation.Logger(), s.config.APIServer.Logging.ExcludedRoutes).Wrap) + r.Use(middleware.NewComment().Wrap) am := middleware.NewAuthZ(s.signoz.Instrumentation.Logger()) diff --git a/pkg/query-service/common/ctx.go b/pkg/query-service/common/ctx.go deleted file mode 100644 index e1599508dad1..000000000000 --- a/pkg/query-service/common/ctx.go +++ /dev/null @@ -1,5 +0,0 @@ -package common - -type LogCommentContextKeyType string - -const LogCommentKey LogCommentContextKeyType = "logComment" diff --git a/pkg/query-service/rules/prom_rule_task.go b/pkg/query-service/rules/prom_rule_task.go index 94fda380faeb..1302cd1775db 100644 --- a/pkg/query-service/rules/prom_rule_task.go +++ b/pkg/query-service/rules/prom_rule_task.go @@ -7,7 +7,7 @@ import ( "sync" "time" - "github.com/SigNoz/signoz/pkg/query-service/common" + "github.com/SigNoz/signoz/pkg/types/ctxtypes" ruletypes "github.com/SigNoz/signoz/pkg/types/ruletypes" "github.com/SigNoz/signoz/pkg/valuer" opentracing "github.com/opentracing/opentracing-go" @@ -369,12 +369,10 @@ func (g *PromRuleTask) Eval(ctx context.Context, ts time.Time) { rule.SetEvaluationTimestamp(t) }(time.Now()) - kvs := map[string]string{ - "alertID": rule.ID(), - "source": "alerts", - "client": "query-service", - } - ctx = context.WithValue(ctx, common.LogCommentKey, kvs) + comment := ctxtypes.CommentFromContext(ctx) + comment.Set("rule_id", rule.ID()) + comment.Set("auth_type", "internal") + ctx = ctxtypes.NewContextWithComment(ctx, comment) _, err := rule.Eval(ctx, ts) if err != nil { diff --git a/pkg/query-service/rules/rule_task.go b/pkg/query-service/rules/rule_task.go index ff0f50d3afc7..be975ad48bd8 100644 --- a/pkg/query-service/rules/rule_task.go +++ b/pkg/query-service/rules/rule_task.go @@ -7,8 +7,8 @@ import ( "sync" "time" - "github.com/SigNoz/signoz/pkg/query-service/common" "github.com/SigNoz/signoz/pkg/query-service/utils/labels" + "github.com/SigNoz/signoz/pkg/types/ctxtypes" ruletypes "github.com/SigNoz/signoz/pkg/types/ruletypes" "github.com/SigNoz/signoz/pkg/valuer" opentracing "github.com/opentracing/opentracing-go" @@ -352,12 +352,10 @@ func (g *RuleTask) Eval(ctx context.Context, ts time.Time) { rule.SetEvaluationTimestamp(t) }(time.Now()) - kvs := map[string]string{ - "alertID": rule.ID(), - "source": "alerts", - "client": "query-service", - } - ctx = context.WithValue(ctx, common.LogCommentKey, kvs) + comment := ctxtypes.CommentFromContext(ctx) + comment.Set("rule_id", rule.ID()) + comment.Set("auth_type", "internal") + ctx = ctxtypes.NewContextWithComment(ctx, comment) _, err := rule.Eval(ctx, ts) if err != nil { diff --git a/pkg/telemetrystore/telemetrystorehook/settings.go b/pkg/telemetrystore/telemetrystorehook/settings.go index 670c2e2b24ef..88a8da34fe35 100644 --- a/pkg/telemetrystore/telemetrystorehook/settings.go +++ b/pkg/telemetrystore/telemetrystorehook/settings.go @@ -2,13 +2,12 @@ package telemetrystorehook import ( "context" - "encoding/json" "strings" "github.com/ClickHouse/clickhouse-go/v2" "github.com/SigNoz/signoz/pkg/factory" - "github.com/SigNoz/signoz/pkg/query-service/common" "github.com/SigNoz/signoz/pkg/telemetrystore" + "github.com/SigNoz/signoz/pkg/types/ctxtypes" ) type provider struct { @@ -32,11 +31,7 @@ func NewSettings(ctx context.Context, providerSettings factory.ProviderSettings, func (h *provider) BeforeQuery(ctx context.Context, _ *telemetrystore.QueryEvent) context.Context { settings := clickhouse.Settings{} - // Apply default settings - logComment := h.getLogComment(ctx) - if logComment != "" { - settings["log_comment"] = logComment - } + settings["log_comment"] = ctxtypes.CommentFromContext(ctx).String() if ctx.Value("enforce_max_result_rows") != nil { settings["max_result_rows"] = h.settings.MaxResultRows @@ -91,22 +86,4 @@ func (h *provider) BeforeQuery(ctx context.Context, _ *telemetrystore.QueryEvent return ctx } -func (h *provider) AfterQuery(ctx context.Context, event *telemetrystore.QueryEvent) { -} - -func (h *provider) getLogComment(ctx context.Context) string { - // Get the key-value pairs from context for log comment - kv := ctx.Value(common.LogCommentKey) - if kv == nil { - return "" - } - - logCommentKVs, ok := kv.(map[string]string) - if !ok { - return "" - } - - logComment, _ := json.Marshal(logCommentKVs) - - return string(logComment) -} +func (h *provider) AfterQuery(ctx context.Context, event *telemetrystore.QueryEvent) {} diff --git a/pkg/types/ctxtypes/comment.go b/pkg/types/ctxtypes/comment.go new file mode 100644 index 000000000000..8c80f11c18e9 --- /dev/null +++ b/pkg/types/ctxtypes/comment.go @@ -0,0 +1,163 @@ +package ctxtypes + +import ( + "context" + "encoding/json" + "net/http" + "net/url" + "regexp" + "sync" +) + +var ( + logsExplorerRegex = regexp.MustCompile(`/logs/logs-explorer(?:\?.*)?$`) + traceExplorerRegex = regexp.MustCompile(`/traces-explorer(?:\?.*)?$`) + metricsExplorerRegex = regexp.MustCompile(`/metrics-explorer/explorer(?:\?.*)?$`) + dashboardRegex = regexp.MustCompile(`/dashboard/[a-zA-Z0-9\-]+/(new|edit)(?:\?.*)?$`) + dashboardIDRegex = regexp.MustCompile(`/dashboard/([a-f0-9\-]+)/`) + widgetIDRegex = regexp.MustCompile(`widgetId=([a-f0-9\-]+)`) + ruleRegex = regexp.MustCompile(`/alerts/(new|edit)(?:\?.*)?$`) + ruleIDRegex = regexp.MustCompile(`ruleId=(\d+)`) +) + +type commentCtxKey struct{} + +type Comment struct { + vals map[string]string + mtx sync.RWMutex +} + +func NewContextWithComment(ctx context.Context, comment *Comment) context.Context { + return context.WithValue(ctx, commentCtxKey{}, comment) +} + +func CommentFromContext(ctx context.Context) *Comment { + comment, ok := ctx.Value(commentCtxKey{}).(*Comment) + if !ok { + return NewComment() + } + + // Return a deep copy of the comment to prevent mutations from affecting the original + copy := NewComment() + copy.Merge(comment.Map()) + return copy +} + +func CommentFromHTTPRequest(req *http.Request) map[string]string { + comments := map[string]string{} + + referrer := req.Header.Get("Referer") + if referrer == "" { + return comments + } + + referrerURL, err := url.Parse(referrer) + if err != nil { + return comments + } + + logsExplorerMatched := logsExplorerRegex.MatchString(referrer) + traceExplorerMatched := traceExplorerRegex.MatchString(referrer) + metricsExplorerMatched := metricsExplorerRegex.MatchString(referrer) + dashboardMatched := dashboardRegex.MatchString(referrer) + ruleMatched := ruleRegex.MatchString(referrer) + + switch { + case dashboardMatched: + comments["module_name"] = "dashboard" + case ruleMatched: + comments["module_name"] = "rule" + case metricsExplorerMatched: + comments["module_name"] = "metrics-explorer" + case logsExplorerMatched: + comments["module_name"] = "logs-explorer" + case traceExplorerMatched: + comments["module_name"] = "traces-explorer" + default: + return comments + } + + if dashboardMatched { + if matches := dashboardIDRegex.FindStringSubmatch(referrer); len(matches) > 1 { + comments["dashboard_id"] = matches[1] + } + + if matches := widgetIDRegex.FindStringSubmatch(referrer); len(matches) > 1 { + comments["widget_id"] = matches[1] + } + } + + if ruleMatched { + if matches := ruleIDRegex.FindStringSubmatch(referrer); len(matches) > 1 { + comments["rule_id"] = matches[1] + } + } + + comments["http_path"] = referrerURL.Path + + return comments +} + +// NewComment creates a new Comment with an empty map. It is safe to use concurrently. +func NewComment() *Comment { + return &Comment{vals: map[string]string{}} +} + +func (comment *Comment) Set(key, value string) { + comment.mtx.Lock() + defer comment.mtx.Unlock() + + comment.vals[key] = value +} + +func (comment *Comment) Merge(vals map[string]string) { + comment.mtx.Lock() + defer comment.mtx.Unlock() + + // If vals is nil, do nothing. Comment should not panic. + if vals == nil { + return + } + + for key, value := range vals { + comment.vals[key] = value + } +} + +func (comment *Comment) Map() map[string]string { + comment.mtx.RLock() + defer comment.mtx.RUnlock() + + copyOfVals := make(map[string]string) + for key, value := range comment.vals { + copyOfVals[key] = value + } + + return copyOfVals +} + +func (comment *Comment) String() string { + comment.mtx.RLock() + defer comment.mtx.RUnlock() + + commentJSON, err := json.Marshal(comment.vals) + if err != nil { + return "{}" + } + + return string(commentJSON) +} + +func (comment *Comment) Equal(other *Comment) bool { + if len(comment.vals) != len(other.vals) { + return false + } + + for key, value := range comment.vals { + if val, ok := other.vals[key]; !ok || val != value { + return false + } + } + + return true +} diff --git a/pkg/types/ctxtypes/comment_test.go b/pkg/types/ctxtypes/comment_test.go new file mode 100644 index 000000000000..35803de54f51 --- /dev/null +++ b/pkg/types/ctxtypes/comment_test.go @@ -0,0 +1,123 @@ +package ctxtypes + +import ( + "context" + "fmt" + "net/http" + "sync" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCommentFromHTTPRequest(t *testing.T) { + testCases := []struct { + name string + req *http.Request + expected map[string]string + }{ + { + name: "EmptyReferer", + req: &http.Request{Header: http.Header{"Referer": {""}}}, + expected: map[string]string{}, + }, + { + name: "ControlCharacterInReferer", + req: &http.Request{Header: http.Header{"Referer": {"https://signoz.io/logs/logs-explorer\x00"}}}, + expected: map[string]string{}, + }, + { + name: "LogsExplorer", + req: &http.Request{Header: http.Header{"Referer": {"https://signoz.io/logs/logs-explorer"}}}, + expected: map[string]string{"http_path": "/logs/logs-explorer", "module_name": "logs-explorer"}, + }, + { + name: "TracesExplorer", + req: &http.Request{Header: http.Header{"Referer": {"https://signoz.io/traces-explorer"}}}, + expected: map[string]string{"http_path": "/traces-explorer", "module_name": "traces-explorer"}, + }, + { + name: "MetricsExplorer", + req: &http.Request{Header: http.Header{"Referer": {"https://signoz.io/metrics-explorer/explorer"}}}, + expected: map[string]string{"http_path": "/metrics-explorer/explorer", "module_name": "metrics-explorer"}, + }, + { + name: "DashboardWithID", + req: &http.Request{Header: http.Header{"Referer": {"https://signoz.io/dashboard/123/new"}}}, + expected: map[string]string{"http_path": "/dashboard/123/new", "module_name": "dashboard", "dashboard_id": "123"}, + }, + { + name: "Rule", + req: &http.Request{Header: http.Header{"Referer": {"https://signoz.io/alerts/new"}}}, + expected: map[string]string{"http_path": "/alerts/new", "module_name": "rule"}, + }, + { + name: "RuleWithID", + req: &http.Request{Header: http.Header{"Referer": {"https://signoz.io/alerts/edit?ruleId=123"}}}, + expected: map[string]string{"http_path": "/alerts/edit", "module_name": "rule", "rule_id": "123"}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + actual := CommentFromHTTPRequest(tc.req) + + assert.True(t, (&Comment{vals: tc.expected}).Equal(&Comment{vals: actual})) + }) + } +} + +func TestCommentFromContext(t *testing.T) { + ctx := context.Background() + comment1 := CommentFromContext(ctx) + assert.True(t, NewComment().Equal(comment1)) + + comment1.Set("k1", "v1") + ctx = NewContextWithComment(ctx, comment1) + actual1 := CommentFromContext(ctx) + assert.True(t, comment1.Equal(actual1)) + + // Get the comment from the context, mutate it, but this time do not set it back in the context + comment2 := CommentFromContext(ctx) + comment2.Set("k2", "v2") + + actual2 := CommentFromContext(ctx) + // Since comment2 was not set back in the context, it should not affect the original comment1 + assert.True(t, comment1.Equal(actual2)) + assert.False(t, comment2.Equal(actual2)) + assert.False(t, comment1.Equal(comment2)) +} + +func TestCommentFromContextConcurrent(t *testing.T) { + comment := NewComment() + comment.Set("k1", "v1") + + ctx := context.Background() + ctx = NewContextWithComment(ctx, comment) + + var wg sync.WaitGroup + ctxs := make([]context.Context, 10) + var mtx sync.Mutex + wg.Add(10) + + for i := 0; i < 10; i++ { + go func(i int) { + defer wg.Done() + comment := CommentFromContext(ctx) + comment.Set("k2", fmt.Sprintf("v%d", i)) + newCtx := NewContextWithComment(ctx, comment) + mtx.Lock() + ctxs[i] = newCtx + mtx.Unlock() + }(i) + } + + wg.Wait() + + for i, ctx := range ctxs { + comment := CommentFromContext(ctx) + assert.Equal(t, len(comment.vals), 2) + assert.Equal(t, comment.vals["k1"], "v1") + assert.Equal(t, comment.vals["k2"], fmt.Sprintf("v%d", i)) + } +}