diff --git a/ee/query-service/app/server.go b/ee/query-service/app/server.go index b83dd4f51b10..a963cf4e33f2 100644 --- a/ee/query-service/app/server.go +++ b/ee/query-service/app/server.go @@ -8,6 +8,8 @@ import ( "net/http" _ "net/http/pprof" // http profiler + "github.com/SigNoz/signoz/pkg/ruler/rulestore/sqlrulestore" + "github.com/gorilla/handlers" "github.com/SigNoz/signoz/ee/query-service/app/api" @@ -334,6 +336,8 @@ func makeRulesManager( querier querier.Querier, logger *slog.Logger, ) (*baserules.Manager, error) { + ruleStore := sqlrulestore.NewRuleStore(sqlstore) + maintenanceStore := sqlrulestore.NewMaintenanceStore(sqlstore) // create manager opts managerOpts := &baserules.ManagerOptions{ TelemetryStore: telemetryStore, @@ -348,8 +352,10 @@ func makeRulesManager( PrepareTaskFunc: rules.PrepareTaskFunc, PrepareTestRuleFunc: rules.TestNotification, Alertmanager: alertmanager, - SQLStore: sqlstore, OrgGetter: orgGetter, + RuleStore: ruleStore, + MaintenanceStore: maintenanceStore, + SqlStore: sqlstore, } // create Manager diff --git a/pkg/query-service/app/server.go b/pkg/query-service/app/server.go index b81db76d1cce..2549810be594 100644 --- a/pkg/query-service/app/server.go +++ b/pkg/query-service/app/server.go @@ -8,6 +8,8 @@ import ( "net/http" _ "net/http/pprof" // http profiler + "github.com/SigNoz/signoz/pkg/ruler/rulestore/sqlrulestore" + "github.com/gorilla/handlers" "github.com/SigNoz/signoz/pkg/alertmanager" @@ -308,20 +310,24 @@ func makeRulesManager( querier querier.Querier, logger *slog.Logger, ) (*rules.Manager, error) { + ruleStore := sqlrulestore.NewRuleStore(sqlstore) + maintenanceStore := sqlrulestore.NewMaintenanceStore(sqlstore) // create manager opts managerOpts := &rules.ManagerOptions{ - TelemetryStore: telemetryStore, - Prometheus: prometheus, - Context: context.Background(), - Logger: zap.L(), - Reader: ch, - Querier: querier, - SLogger: logger, - Cache: cache, - EvalDelay: constants.GetEvalDelay(), - SQLStore: sqlstore, - OrgGetter: orgGetter, - Alertmanager: alertmanager, + TelemetryStore: telemetryStore, + Prometheus: prometheus, + Context: context.Background(), + Logger: zap.L(), + Reader: ch, + Querier: querier, + SLogger: logger, + Cache: cache, + EvalDelay: constants.GetEvalDelay(), + OrgGetter: orgGetter, + Alertmanager: alertmanager, + RuleStore: ruleStore, + MaintenanceStore: maintenanceStore, + SqlStore: sqlstore, } // create Manager diff --git a/pkg/query-service/rules/manager.go b/pkg/query-service/rules/manager.go index 5264b28f85ff..1a57414f13e2 100644 --- a/pkg/query-service/rules/manager.go +++ b/pkg/query-service/rules/manager.go @@ -22,7 +22,6 @@ import ( querierV5 "github.com/SigNoz/signoz/pkg/querier" "github.com/SigNoz/signoz/pkg/query-service/interfaces" "github.com/SigNoz/signoz/pkg/query-service/model" - "github.com/SigNoz/signoz/pkg/ruler/rulestore/sqlrulestore" "github.com/SigNoz/signoz/pkg/sqlstore" "github.com/SigNoz/signoz/pkg/telemetrystore" "github.com/SigNoz/signoz/pkg/types" @@ -98,8 +97,10 @@ type ManagerOptions struct { PrepareTaskFunc func(opts PrepareTaskOptions) (Task, error) PrepareTestRuleFunc func(opts PrepareTestRuleOptions) (int, *model.ApiError) Alertmanager alertmanager.Alertmanager - SQLStore sqlstore.SQLStore OrgGetter organization.Getter + RuleStore ruletypes.RuleStore + MaintenanceStore ruletypes.MaintenanceStore + SqlStore sqlstore.SQLStore } // The Manager manages recording and alerting rules. @@ -207,14 +208,12 @@ func defaultPrepareTaskFunc(opts PrepareTaskOptions) (Task, error) { // by calling the Run method. func NewManager(o *ManagerOptions) (*Manager, error) { o = defaultOptions(o) - ruleStore := sqlrulestore.NewRuleStore(o.SQLStore) - maintenanceStore := sqlrulestore.NewMaintenanceStore(o.SQLStore) m := &Manager{ tasks: map[string]Task{}, rules: map[string]Rule{}, - ruleStore: ruleStore, - maintenanceStore: maintenanceStore, + ruleStore: o.RuleStore, + maintenanceStore: o.MaintenanceStore, opts: o, block: make(chan struct{}), logger: o.Logger, @@ -223,8 +222,8 @@ func NewManager(o *ManagerOptions) (*Manager, error) { prepareTaskFunc: o.PrepareTaskFunc, prepareTestRuleFunc: o.PrepareTestRuleFunc, alertmanager: o.Alertmanager, - sqlstore: o.SQLStore, orgGetter: o.OrgGetter, + sqlstore: o.SqlStore, } return m, nil @@ -896,33 +895,37 @@ func (m *Manager) PatchRule(ctx context.Context, ruleStr string, id valuer.UUID) return nil, err } - // storedRule holds the current stored rule from DB - patchedRule := ruletypes.PostableRule{} - if err := json.Unmarshal([]byte(ruleStr), &patchedRule); err != nil { - zap.L().Error("failed to unmarshal stored rule with given id", zap.String("id", id.StringValue()), zap.Error(err)) + storedRule := ruletypes.PostableRule{} + if err := json.Unmarshal([]byte(storedJSON.Data), &storedRule); err != nil { + zap.L().Error("failed to unmarshal rule from db", zap.String("id", id.StringValue()), zap.Error(err)) + return nil, err + } + + if err := json.Unmarshal([]byte(ruleStr), &storedRule); err != nil { + zap.L().Error("failed to unmarshal patched rule with given id", zap.String("id", id.StringValue()), zap.Error(err)) return nil, err } // deploy or un-deploy task according to patched (new) rule state - if err := m.syncRuleStateWithTask(ctx, orgID, taskName, &patchedRule); err != nil { + if err := m.syncRuleStateWithTask(ctx, orgID, taskName, &storedRule); err != nil { zap.L().Error("failed to sync stored rule state with the task", zap.String("taskName", taskName), zap.Error(err)) return nil, err } - // prepare rule json to write to update db - patchedRuleBytes, err := json.Marshal(patchedRule) + newStoredJson, err := json.Marshal(&storedRule) if err != nil { + zap.L().Error("failed to marshal new stored rule with given id", zap.String("id", id.StringValue()), zap.Error(err)) return nil, err } now := time.Now() - storedJSON.Data = string(patchedRuleBytes) + storedJSON.Data = string(newStoredJson) storedJSON.UpdatedBy = claims.Email storedJSON.UpdatedAt = now err = m.ruleStore.EditRule(ctx, storedJSON, func(ctx context.Context) error { return nil }) if err != nil { - if err := m.syncRuleStateWithTask(ctx, orgID, taskName, &patchedRule); err != nil { + if err := m.syncRuleStateWithTask(ctx, orgID, taskName, &storedRule); err != nil { zap.L().Error("failed to restore rule after patch failure", zap.String("taskName", taskName), zap.Error(err)) } return nil, err @@ -931,7 +934,7 @@ func (m *Manager) PatchRule(ctx context.Context, ruleStr string, id valuer.UUID) // prepare http response response := ruletypes.GettableRule{ Id: id.StringValue(), - PostableRule: patchedRule, + PostableRule: storedRule, } // fetch state of rule from memory diff --git a/pkg/query-service/rules/manager_test.go b/pkg/query-service/rules/manager_test.go new file mode 100644 index 000000000000..795a918345b1 --- /dev/null +++ b/pkg/query-service/rules/manager_test.go @@ -0,0 +1,610 @@ +package rules + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "go.uber.org/zap" + + "github.com/SigNoz/signoz/pkg/alertmanager" + "github.com/SigNoz/signoz/pkg/alertmanager/alertmanagerserver" + "github.com/SigNoz/signoz/pkg/alertmanager/signozalertmanager" + "github.com/SigNoz/signoz/pkg/instrumentation/instrumentationtest" + "github.com/SigNoz/signoz/pkg/modules/organization/implorganization" + "github.com/SigNoz/signoz/pkg/query-service/utils" + "github.com/SigNoz/signoz/pkg/ruler/rulestore/rulestoretest" + "github.com/SigNoz/signoz/pkg/sharder" + "github.com/SigNoz/signoz/pkg/sharder/noopsharder" + "github.com/SigNoz/signoz/pkg/types" + "github.com/SigNoz/signoz/pkg/types/alertmanagertypes" + "github.com/SigNoz/signoz/pkg/types/authtypes" + "github.com/SigNoz/signoz/pkg/types/ruletypes" + "github.com/SigNoz/signoz/pkg/valuer" +) + +func TestManager_PatchRule_PayloadVariations(t *testing.T) { + // Set up test claims and manager once for all test cases + claims := &authtypes.Claims{ + UserID: "550e8400-e29b-41d4-a716-446655440000", + Email: "test@example.com", + Role: "admin", + } + manager, mockSQLRuleStore, orgId := setupTestManager(t) + claims.OrgID = orgId + + testCases := []struct { + name string + originalData string + patchData string + expectedResult func(*ruletypes.GettableRule) bool + expectError bool + description string + }{ + { + name: "patch complete rule with task sync validation", + originalData: `{ + "schemaVersion":"v1", + "alert": "test-original-alert", + "alertType": "METRIC_BASED_ALERT", + "ruleType": "threshold_rule", + "evalWindow": "5m0s", + "condition": { + "compositeQuery": { + "queryType": "builder", + "panelType": "graph", + "queries": [ + { + "type": "builder_query", + "spec": { + "name": "A", + "signal": "metrics", + "disabled": false, + "aggregations": [ + { + "metricName": "container.cpu.time", + "timeAggregation": "rate", + "spaceAggregation": "sum" + } + ] + } + } + ] + } + }, + "labels": { + "severity": "warning" + }, + "disabled": false, + "preferredChannels": ["test-alerts"] + }`, + patchData: `{ + "alert": "test-patched-alert", + "labels": { + "severity": "critical" + } + }`, + expectedResult: func(result *ruletypes.GettableRule) bool { + return result.AlertName == "test-patched-alert" && + result.Labels["severity"] == "critical" && + result.Disabled == false + }, + expectError: false, + }, + { + name: "patch rule to disabled state", + originalData: `{ + "schemaVersion":"v2", + "alert": "test-disable-alert", + "alertType": "METRIC_BASED_ALERT", + "ruleType": "threshold_rule", + "evalWindow": "5m0s", + "condition": { + "thresholds": { + "kind": "basic", + "spec": [ + { + "name": "WARNING", + "target": 30, + "matchType": "1", + "op": "1", + "selectedQuery": "A", + "channels": ["test-alerts"] + } + ] + }, + "compositeQuery": { + "queryType": "builder", + "panelType": "graph", + "queries": [ + { + "type": "builder_query", + "spec": { + "name": "A", + "signal": "metrics", + "disabled": false, + "aggregations": [ + { + "metricName": "container.memory.usage", + "timeAggregation": "avg", + "spaceAggregation": "sum" + } + ] + } + } + ] + } + }, + "evaluation": { + "kind": "rolling", + "spec": { + "evalWindow": "5m", + "frequency": "1m" + } + }, + "labels": { + "severity": "warning" + }, + "disabled": false, + "preferredChannels": ["test-alerts"] + }`, + patchData: `{ + "disabled": true + }`, + expectedResult: func(result *ruletypes.GettableRule) bool { + return result.Disabled == true + }, + expectError: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ruleID := valuer.GenerateUUID() + existingRule := &ruletypes.Rule{ + Identifiable: types.Identifiable{ + ID: ruleID, + }, + TimeAuditable: types.TimeAuditable{ + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }, + UserAuditable: types.UserAuditable{ + CreatedBy: "creator@example.com", + UpdatedBy: "creator@example.com", + }, + Data: tc.originalData, + OrgID: claims.OrgID, + } + + mockSQLRuleStore.ExpectGetStoredRule(ruleID, existingRule) + mockSQLRuleStore.ExpectEditRule(existingRule) + + ctx := authtypes.NewContextWithClaims(context.Background(), *claims) + result, err := manager.PatchRule(ctx, tc.patchData, ruleID) + + assert.NoError(t, err) + assert.NotNil(t, result) + assert.Equal(t, ruleID.StringValue(), result.Id) + + if tc.expectedResult != nil { + assert.True(t, tc.expectedResult(result), "Expected result validation failed") + } + taskName := prepareTaskName(result.Id) + + if result.Disabled { + syncCompleted := waitForTaskSync(manager, taskName, false, 2*time.Second) + assert.True(t, syncCompleted, "Task synchronization should complete within timeout") + assert.Nil(t, findTaskByName(manager.RuleTasks(), taskName), "Task should be removed for disabled rule") + } else { + syncCompleted := waitForTaskSync(manager, taskName, true, 2*time.Second) + assert.True(t, syncCompleted, "Task synchronization should complete within timeout") + assert.NotNil(t, findTaskByName(manager.RuleTasks(), taskName), "Task should be created/updated for enabled rule") + assert.Greater(t, len(manager.Rules()), 0, "Rules should be updated in manager") + } + + assert.NoError(t, mockSQLRuleStore.AssertExpectations()) + }) + } +} + +func waitForTaskSync(manager *Manager, taskName string, expectedExists bool, timeout time.Duration) bool { + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + task := findTaskByName(manager.RuleTasks(), taskName) + exists := task != nil + + if exists == expectedExists { + return true + } + time.Sleep(10 * time.Millisecond) + } + return false +} + +// findTaskByName finds a task by name in the slice of tasks +func findTaskByName(tasks []Task, taskName string) Task { + for i := 0; i < len(tasks); i++ { + if tasks[i].Name() == taskName { + return tasks[i] + } + } + return nil +} + +func setupTestManager(t *testing.T) (*Manager, *rulestoretest.MockSQLRuleStore, string) { + settings := instrumentationtest.New().ToProviderSettings() + testDB := utils.NewQueryServiceDBForTests(t) + + err := utils.CreateTestOrg(t, testDB) + if err != nil { + t.Fatalf("Failed to create test org: %v", err) + } + testOrgID, err := utils.GetTestOrgId(testDB) + if err != nil { + t.Fatalf("Failed to get test org ID: %v", err) + } + + //will replace this with alertmanager mock + newConfig := alertmanagerserver.NewConfig() + defaultConfig, err := alertmanagertypes.NewDefaultConfig(newConfig.Global, newConfig.Route, testOrgID.StringValue()) + if err != nil { + t.Fatalf("Failed to create default alertmanager config: %v", err) + } + + _, err = testDB.BunDB().NewInsert(). + Model(defaultConfig.StoreableConfig()). + Exec(context.Background()) + if err != nil { + t.Fatalf("Failed to insert alertmanager config: %v", err) + } + + noopSharder, err := noopsharder.New(context.TODO(), settings, sharder.Config{}) + if err != nil { + t.Fatalf("Failed to create noop sharder: %v", err) + } + orgGetter := implorganization.NewGetter(implorganization.NewStore(testDB), noopSharder) + alertManager, err := signozalertmanager.New(context.TODO(), settings, alertmanager.Config{Provider: "signoz", Signoz: alertmanager.Signoz{PollInterval: 10 * time.Second, Config: alertmanagerserver.NewConfig()}}, testDB, orgGetter) + if err != nil { + t.Fatalf("Failed to create alert manager: %v", err) + } + mockSQLRuleStore := rulestoretest.NewMockSQLRuleStore() + + options := ManagerOptions{ + Context: context.Background(), + Logger: zap.L(), + SLogger: instrumentationtest.New().Logger(), + EvalDelay: time.Minute, + PrepareTaskFunc: defaultPrepareTaskFunc, + Alertmanager: alertManager, + OrgGetter: orgGetter, + RuleStore: mockSQLRuleStore, + } + + manager, err := NewManager(&options) + if err != nil { + t.Fatalf("Failed to create manager: %v", err) + } + + close(manager.block) + return manager, mockSQLRuleStore, testOrgID.StringValue() +} + +func TestCreateRule(t *testing.T) { + claims := &authtypes.Claims{ + Email: "test@example.com", + } + manager, mockSQLRuleStore, orgId := setupTestManager(t) + claims.OrgID = orgId + testCases := []struct { + name string + ruleStr string + }{ + { + name: "validate stored rule data structure", + ruleStr: `{ + "alert": "cpu usage", + "ruleType": "threshold_rule", + "evalWindow": "5m", + "frequency": "1m", + "condition": { + "compositeQuery": { + "queryType": "builder", + "builderQueries": { + "A": { + "expression": "A", + "disabled": false, + "dataSource": "metrics", + "aggregateOperator": "avg", + "aggregateAttribute": { + "key": "cpu_usage", + "type": "Gauge" + } + } + } + }, + "op": "1", + "target": 80, + "matchType": "1" + }, + "labels": { + "severity": "warning" + }, + "annotations": { + "summary": "High CPU usage detected" + }, + "preferredChannels": ["test-alerts"] + }`, + }, + { + name: "create complete v2 rule with thresholds", + ruleStr: `{ + "schemaVersion":"v2", + "state": "firing", + "alert": "test-multi-threshold-create", + "alertType": "METRIC_BASED_ALERT", + "ruleType": "threshold_rule", + "evalWindow": "5m0s", + "condition": { + "thresholds": { + "kind": "basic", + "spec": [ + { + "name": "CRITICAL", + "target": 0, + "matchType": "1", + "op": "1", + "selectedQuery": "A", + "channels": ["test-alerts"] + }, + { + "name": "WARNING", + "target": 0, + "matchType": "1", + "op": "1", + "selectedQuery": "A", + "channels": ["test-alerts"] + } + ] + }, + "compositeQuery": { + "queryType": "builder", + "panelType": "graph", + "queries": [ + { + "type": "builder_query", + "spec": { + "name": "A", + "signal": "metrics", + "disabled": false, + "aggregations": [ + { + "metricName": "container.cpu.time", + "timeAggregation": "rate", + "spaceAggregation": "sum" + } + ] + } + } + ] + } + }, + "evaluation": { + "kind": "rolling", + "spec": { + "evalWindow": "6m", + "frequency": "1m" + } + }, + "labels": { + "severity": "warning" + }, + "annotations": { + "description": "This alert is fired when the defined metric crosses the threshold", + "summary": "The rule threshold is set and the observed metric value is evaluated" + }, + "disabled": false, + "preferredChannels": ["#test-alerts-v2"], + "version": "v5" + }`, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rule := &ruletypes.Rule{ + Identifiable: types.Identifiable{ + ID: valuer.GenerateUUID(), + }, + TimeAuditable: types.TimeAuditable{ + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }, + UserAuditable: types.UserAuditable{ + CreatedBy: claims.Email, + UpdatedBy: claims.Email, + }, + OrgID: claims.OrgID, + } + mockSQLRuleStore.ExpectCreateRule(rule) + + ctx := authtypes.NewContextWithClaims(context.Background(), *claims) + result, err := manager.CreateRule(ctx, tc.ruleStr) + + assert.NoError(t, err) + assert.NotNil(t, result) + assert.NotEmpty(t, result.Id, "Result should have a valid ID") + + // Wait for task creation with proper synchronization + taskName := prepareTaskName(result.Id) + syncCompleted := waitForTaskSync(manager, taskName, true, 2*time.Second) + assert.True(t, syncCompleted, "Task creation should complete within timeout") + assert.NotNil(t, findTaskByName(manager.RuleTasks(), taskName), "Task should be created with correct name") + assert.Greater(t, len(manager.Rules()), 0, "Rules should be added to manager") + + assert.NoError(t, mockSQLRuleStore.AssertExpectations()) + }) + } +} + +func TestEditRule(t *testing.T) { + // Set up test claims and manager once for all test cases + claims := &authtypes.Claims{ + Email: "test@example.com", + } + manager, mockSQLRuleStore, orgId := setupTestManager(t) + claims.OrgID = orgId + testCases := []struct { + name string + ruleStr string + }{ + { + name: "validate edit rule functionality", + ruleStr: `{ + "alert": "updated cpu usage", + "ruleType": "threshold_rule", + "evalWindow": "10m", + "frequency": "2m", + "condition": { + "compositeQuery": { + "queryType": "builder", + "builderQueries": { + "A": { + "expression": "A", + "disabled": false, + "dataSource": "metrics", + "aggregateOperator": "avg", + "aggregateAttribute": { + "key": "cpu_usage", + "type": "Gauge" + } + } + } + }, + "op": "1", + "target": 90, + "matchType": "1" + }, + "labels": { + "severity": "critical" + }, + "annotations": { + "summary": "Very high CPU usage detected" + }, + "preferredChannels": ["critical-alerts"] + }`, + }, + { + name: "edit complete v2 rule with thresholds", + ruleStr: `{ + "schemaVersion":"v2", + "state": "firing", + "alert": "test-multi-threshold-edit", + "alertType": "METRIC_BASED_ALERT", + "ruleType": "threshold_rule", + "evalWindow": "5m0s", + "condition": { + "thresholds": { + "kind": "basic", + "spec": [ + { + "name": "CRITICAL", + "target": 10, + "matchType": "1", + "op": "1", + "selectedQuery": "A", + "channels": ["test-alerts"] + }, + { + "name": "WARNING", + "target": 5, + "matchType": "1", + "op": "1", + "selectedQuery": "A", + "channels": ["test-alerts"] + } + ] + }, + "compositeQuery": { + "queryType": "builder", + "panelType": "graph", + "queries": [ + { + "type": "builder_query", + "spec": { + "name": "A", + "signal": "metrics", + "disabled": false, + "aggregations": [ + { + "metricName": "container.memory.usage", + "timeAggregation": "avg", + "spaceAggregation": "sum" + } + ] + } + } + ] + } + }, + "evaluation": { + "kind": "rolling", + "spec": { + "evalWindow": "8m", + "frequency": "2m" + } + }, + "labels": { + "severity": "critical" + }, + "annotations": { + "description": "This alert is fired when memory usage crosses the threshold", + "summary": "Memory usage threshold exceeded" + }, + "disabled": false, + "preferredChannels": ["#critical-alerts-v2"], + "version": "v5" + }`, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ruleID := valuer.GenerateUUID() + + existingRule := &ruletypes.Rule{ + Identifiable: types.Identifiable{ + ID: ruleID, + }, + TimeAuditable: types.TimeAuditable{ + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + }, + UserAuditable: types.UserAuditable{ + CreatedBy: "creator@example.com", + UpdatedBy: "creator@example.com", + }, + Data: `{"alert": "original cpu usage", "disabled": false}`, + OrgID: claims.OrgID, + } + + mockSQLRuleStore.ExpectGetStoredRule(ruleID, existingRule) + mockSQLRuleStore.ExpectEditRule(existingRule) + + ctx := authtypes.NewContextWithClaims(context.Background(), *claims) + err := manager.EditRule(ctx, tc.ruleStr, ruleID) + + assert.NoError(t, err) + + // Wait for task update with proper synchronization + taskName := prepareTaskName(ruleID.StringValue()) + syncCompleted := waitForTaskSync(manager, taskName, true, 2*time.Second) + assert.True(t, syncCompleted, "Task update should complete within timeout") + assert.NotNil(t, findTaskByName(manager.RuleTasks(), taskName), "Task should be updated with correct name") + assert.Greater(t, len(manager.Rules()), 0, "Rules should be updated in manager") + + assert.NoError(t, mockSQLRuleStore.AssertExpectations()) + }) + } +} diff --git a/pkg/ruler/rulestore/rulestoretest/rule.go b/pkg/ruler/rulestore/rulestoretest/rule.go new file mode 100644 index 000000000000..01a2c3a87f8f --- /dev/null +++ b/pkg/ruler/rulestore/rulestoretest/rule.go @@ -0,0 +1,110 @@ +package rulestoretest + +import ( + "context" + "regexp" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/SigNoz/signoz/pkg/ruler/rulestore/sqlrulestore" + "github.com/SigNoz/signoz/pkg/sqlstore" + "github.com/SigNoz/signoz/pkg/sqlstore/sqlstoretest" + ruletypes "github.com/SigNoz/signoz/pkg/types/ruletypes" + "github.com/SigNoz/signoz/pkg/valuer" +) + +// MockSQLRuleStore is a mock RuleStore backed by sqlmock +type MockSQLRuleStore struct { + ruleStore ruletypes.RuleStore + mock sqlmock.Sqlmock +} + +// NewMockSQLRuleStore creates a new MockSQLRuleStore with sqlmock +func NewMockSQLRuleStore() *MockSQLRuleStore { + sqlStore := sqlstoretest.New(sqlstore.Config{Provider: "sqlite"}, sqlmock.QueryMatcherRegexp) + ruleStore := sqlrulestore.NewRuleStore(sqlStore) + + return &MockSQLRuleStore{ + ruleStore: ruleStore, + mock: sqlStore.Mock(), + } +} + +// Mock returns the sqlmock.Sqlmock instance for setting expectations +func (m *MockSQLRuleStore) Mock() sqlmock.Sqlmock { + return m.mock +} + +// CreateRule implements ruletypes.RuleStore - delegates to underlying ruleStore to trigger SQL +func (m *MockSQLRuleStore) CreateRule(ctx context.Context, rule *ruletypes.Rule, fn func(context.Context, valuer.UUID) error) (valuer.UUID, error) { + return m.ruleStore.CreateRule(ctx, rule, fn) +} + +// EditRule implements ruletypes.RuleStore - delegates to underlying ruleStore to trigger SQL +func (m *MockSQLRuleStore) EditRule(ctx context.Context, rule *ruletypes.Rule, fn func(context.Context) error) error { + return m.ruleStore.EditRule(ctx, rule, fn) +} + +// DeleteRule implements ruletypes.RuleStore - delegates to underlying ruleStore to trigger SQL +func (m *MockSQLRuleStore) DeleteRule(ctx context.Context, id valuer.UUID, fn func(context.Context) error) error { + return m.ruleStore.DeleteRule(ctx, id, fn) +} + +// GetStoredRule implements ruletypes.RuleStore - delegates to underlying ruleStore to trigger SQL +func (m *MockSQLRuleStore) GetStoredRule(ctx context.Context, id valuer.UUID) (*ruletypes.Rule, error) { + return m.ruleStore.GetStoredRule(ctx, id) +} + +// GetStoredRules implements ruletypes.RuleStore - delegates to underlying ruleStore to trigger SQL +func (m *MockSQLRuleStore) GetStoredRules(ctx context.Context, orgID string) ([]*ruletypes.Rule, error) { + return m.ruleStore.GetStoredRules(ctx, orgID) +} + +// ExpectCreateRule sets up SQL expectations for CreateRule operation +func (m *MockSQLRuleStore) ExpectCreateRule(rule *ruletypes.Rule) { + rows := sqlmock.NewRows([]string{"id", "created_at", "updated_at", "created_by", "updated_by", "deleted", "data", "org_id"}). + AddRow(rule.ID, rule.CreatedAt, rule.UpdatedAt, rule.CreatedBy, rule.UpdatedBy, rule.Deleted, rule.Data, rule.OrgID) + expectedPattern := `INSERT INTO "rule" \(.+\) VALUES \(.+` + + regexp.QuoteMeta(rule.CreatedBy) + `.+` + + regexp.QuoteMeta(rule.OrgID) + `.+\) RETURNING` + m.mock.ExpectQuery(expectedPattern). + WillReturnRows(rows) +} + +// ExpectEditRule sets up SQL expectations for EditRule operation +func (m *MockSQLRuleStore) ExpectEditRule(rule *ruletypes.Rule) { + expectedPattern := `UPDATE "rule".+` + rule.UpdatedBy + `.+` + rule.OrgID + `.+WHERE \(id = '` + rule.ID.StringValue() + `'\)` + m.mock.ExpectExec(expectedPattern). + WillReturnResult(sqlmock.NewResult(1, 1)) +} + +// ExpectDeleteRule sets up SQL expectations for DeleteRule operation +func (m *MockSQLRuleStore) ExpectDeleteRule(ruleID valuer.UUID) { + expectedPattern := `DELETE FROM "rule".+WHERE \(id = '` + ruleID.StringValue() + `'\)` + m.mock.ExpectExec(expectedPattern). + WillReturnResult(sqlmock.NewResult(1, 1)) +} + +// ExpectGetStoredRule sets up SQL expectations for GetStoredRule operation +func (m *MockSQLRuleStore) ExpectGetStoredRule(ruleID valuer.UUID, rule *ruletypes.Rule) { + rows := sqlmock.NewRows([]string{"id", "created_at", "updated_at", "created_by", "updated_by", "deleted", "data", "org_id"}). + AddRow(rule.ID, rule.CreatedAt, rule.UpdatedAt, rule.CreatedBy, rule.UpdatedBy, rule.Deleted, rule.Data, rule.OrgID) + expectedPattern := `SELECT (.+) FROM "rule".+WHERE \(id = '` + ruleID.StringValue() + `'\)` + m.mock.ExpectQuery(expectedPattern). + WillReturnRows(rows) +} + +// ExpectGetStoredRules sets up SQL expectations for GetStoredRules operation +func (m *MockSQLRuleStore) ExpectGetStoredRules(orgID string, rules []*ruletypes.Rule) { + rows := sqlmock.NewRows([]string{"id", "created_at", "updated_at", "created_by", "updated_by", "deleted", "data", "org_id"}) + for _, rule := range rules { + rows.AddRow(rule.ID, rule.CreatedAt, rule.UpdatedAt, rule.CreatedBy, rule.UpdatedBy, rule.Deleted, rule.Data, rule.OrgID) + } + expectedPattern := `SELECT (.+) FROM "rule".+WHERE \(.+org_id.+'` + orgID + `'\)` + m.mock.ExpectQuery(expectedPattern). + WillReturnRows(rows) +} + +// AssertExpectations asserts that all SQL expectations were met +func (m *MockSQLRuleStore) AssertExpectations() error { + return m.mock.ExpectationsWereMet() +} diff --git a/pkg/types/ruletypes/api_params.go b/pkg/types/ruletypes/api_params.go index 9285b070fbe6..ff5e8917b6d1 100644 --- a/pkg/types/ruletypes/api_params.go +++ b/pkg/types/ruletypes/api_params.go @@ -23,6 +23,10 @@ const ( AlertTypeExceptions AlertType = "EXCEPTIONS_BASED_ALERT" ) +const ( + DefaultSchemaVersion = "v1" +) + type RuleDataKind string const ( @@ -51,11 +55,16 @@ type PostableRule struct { Version string `json:"version,omitempty"` - Evaluation *EvaluationEnvelope `yaml:"evaluation,omitempty" json:"evaluation,omitempty"` + Evaluation *EvaluationEnvelope `yaml:"evaluation,omitempty" json:"evaluation,omitempty"` + SchemaVersion string `json:"schemaVersion,omitempty"` } func (r *PostableRule) processRuleDefaults() error { + if r.SchemaVersion == "" { + r.SchemaVersion = DefaultSchemaVersion + } + if r.EvalWindow == 0 { r.EvalWindow = Duration(5 * time.Minute) } @@ -79,7 +88,7 @@ func (r *PostableRule) processRuleDefaults() error { } } //added alerts v2 fields - if r.RuleCondition.Thresholds == nil { + if r.SchemaVersion == DefaultSchemaVersion { thresholdName := CriticalThresholdName if r.Labels != nil { if severity, ok := r.Labels["severity"]; ok { @@ -98,15 +107,33 @@ func (r *PostableRule) processRuleDefaults() error { }}, } r.RuleCondition.Thresholds = &thresholdData + r.Evaluation = &EvaluationEnvelope{RollingEvaluation, RollingWindow{EvalWindow: r.EvalWindow, Frequency: r.Frequency}} } } - if r.Evaluation == nil { - r.Evaluation = &EvaluationEnvelope{RollingEvaluation, RollingWindow{EvalWindow: r.EvalWindow, Frequency: r.Frequency}} - } return r.Validate() } +func (r *PostableRule) MarshalJSON() ([]byte, error) { + type Alias PostableRule + + switch r.SchemaVersion { + case DefaultSchemaVersion: + copyStruct := *r + aux := Alias(copyStruct) + if aux.RuleCondition != nil { + aux.RuleCondition.Thresholds = nil + } + aux.Evaluation = nil + aux.SchemaVersion = "" + return json.Marshal(aux) + default: + copyStruct := *r + aux := Alias(copyStruct) + return json.Marshal(aux) + } +} + func (r *PostableRule) UnmarshalJSON(bytes []byte) error { type Alias PostableRule aux := (*Alias)(r) @@ -263,3 +290,23 @@ type GettableRule struct { UpdatedAt *time.Time `json:"updateAt"` UpdatedBy *string `json:"updateBy"` } + +func (g *GettableRule) MarshalJSON() ([]byte, error) { + type Alias GettableRule + + switch g.SchemaVersion { + case DefaultSchemaVersion: + copyStruct := *g + aux := Alias(copyStruct) + if aux.RuleCondition != nil { + aux.RuleCondition.Thresholds = nil + } + aux.Evaluation = nil + aux.SchemaVersion = "" + return json.Marshal(aux) + default: + copyStruct := *g + aux := Alias(copyStruct) + return json.Marshal(aux) + } +} diff --git a/pkg/types/ruletypes/api_params_test.go b/pkg/types/ruletypes/api_params_test.go index 27ec5883714e..74d58fdb39a0 100644 --- a/pkg/types/ruletypes/api_params_test.go +++ b/pkg/types/ruletypes/api_params_test.go @@ -240,6 +240,338 @@ func TestParseIntoRule(t *testing.T) { } } +func TestParseIntoRuleSchemaVersioning(t *testing.T) { + tests := []struct { + name string + initRule PostableRule + content []byte + kind RuleDataKind + expectError bool + validate func(*testing.T, *PostableRule) + }{ + { + name: "schema v1 - threshold name from severity label", + initRule: PostableRule{}, + content: []byte(`{ + "alert": "SeverityLabelTest", + "schemaVersion": "v1", + "condition": { + "compositeQuery": { + "queryType": "builder", + "builderQueries": { + "A": { + "aggregateAttribute": { + "key": "cpu_usage" + } + } + }, + "unit": "percent" + }, + "target": 85.0, + "targetUnit": "%", + "matchType": "1", + "op": "1" + }, + "labels": { + "severity": "warning", + "team": "platform" + } + }`), + kind: RuleDataKindJson, + expectError: false, + validate: func(t *testing.T, rule *PostableRule) { + if rule.RuleCondition.Thresholds == nil { + t.Fatal("Expected Thresholds to be populated for v1") + } + + threshold := rule.RuleCondition.Thresholds + if threshold.Kind != BasicThresholdKind { + t.Errorf("Expected BasicThresholdKind, got %s", threshold.Kind) + } + + specs, ok := threshold.Spec.(BasicRuleThresholds) + if !ok { + t.Fatalf("Expected BasicRuleThresholds, got %T", threshold.Spec) + } + + if len(specs) != 1 { + t.Fatalf("Expected 1 threshold spec, got %d", len(specs)) + } + + spec := specs[0] + if spec.Name != "warning" { + t.Errorf("Expected threshold name 'warning' from severity label, got '%s'", spec.Name) + } + + // Verify all fields are copied from RuleCondition + if spec.RuleUnit != "percent" { + t.Errorf("Expected RuleUnit 'percent', got '%s'", spec.RuleUnit) + } + if spec.TargetUnit != "%" { + t.Errorf("Expected TargetUnit '%%', got '%s'", spec.TargetUnit) + } + if *spec.TargetValue != 85.0 { + t.Errorf("Expected TargetValue 85.0, got %v", *spec.TargetValue) + } + if spec.MatchType != rule.RuleCondition.MatchType { + t.Error("Expected MatchType to be copied from RuleCondition") + } + if spec.CompareOp != rule.RuleCondition.CompareOp { + t.Error("Expected CompareOp to be copied from RuleCondition") + } + + // Verify evaluation envelope is populated + if rule.Evaluation == nil { + t.Fatal("Expected Evaluation to be populated for v1") + } + if rule.Evaluation.Kind != RollingEvaluation { + t.Errorf("Expected RollingEvaluation, got %s", rule.Evaluation.Kind) + } + + // Verify evaluation window matches rule settings + if window, ok := rule.Evaluation.Spec.(RollingWindow); ok { + if window.EvalWindow != rule.EvalWindow { + t.Errorf("Expected Evaluation EvalWindow %v, got %v", rule.EvalWindow, window.EvalWindow) + } + if window.Frequency != rule.Frequency { + t.Errorf("Expected Evaluation Frequency %v, got %v", rule.Frequency, window.Frequency) + } + } else { + t.Errorf("Expected RollingWindow spec, got %T", rule.Evaluation.Spec) + } + }, + }, + { + name: "schema v1 - uses critical threshold when no labels", + initRule: PostableRule{}, + content: []byte(`{ + "alert": "NoLabelsTest", + "schemaVersion": "v1", + "condition": { + "compositeQuery": { + "queryType": "builder", + "builderQueries": { + "A": { + "aggregateAttribute": { + "key": "memory_usage" + } + } + } + }, + "target": 90.0, + "matchType": "1", + "op": "1" + } + }`), + kind: RuleDataKindJson, + expectError: false, + validate: func(t *testing.T, rule *PostableRule) { + if rule.RuleCondition.Thresholds == nil { + t.Fatal("Expected Thresholds to be populated") + } + + specs, ok := rule.RuleCondition.Thresholds.Spec.(BasicRuleThresholds) + if !ok { + t.Fatalf("Expected BasicRuleThresholds, got %T", rule.RuleCondition.Thresholds.Spec) + } + spec := specs[0] + // Should default to CriticalThresholdName when no severity label + if spec.Name != CriticalThresholdName { + t.Errorf("Expected threshold name '%s', got '%s'", CriticalThresholdName, spec.Name) + } + }, + }, + { + name: "schema v1 - overwrites existing thresholds and evaluation", + initRule: PostableRule{}, + content: []byte(`{ + "alert": "OverwriteTest", + "schemaVersion": "v1", + "condition": { + "compositeQuery": { + "queryType": "builder", + "builderQueries": { + "A": { + "aggregateAttribute": { + "key": "cpu_usage" + } + } + }, + "unit": "percent" + }, + "target": 80.0, + "targetUnit": "%", + "matchType": "1", + "op": "1", + "thresholds": { + "kind": "basic", + "spec": [{ + "name": "existing_threshold", + "target": 50.0, + "targetUnit": "MB", + "ruleUnit": "bytes", + "matchType": "1", + "op": "1" + }] + } + }, + "evaluation": { + "kind": "rolling", + "spec": { + "evalWindow": "10m", + "frequency": "2m" + } + }, + "frequency":"7m", + "evalWindow":"11m", + "labels": { + "severity": "critical" + } + }`), + kind: RuleDataKindJson, + expectError: false, + validate: func(t *testing.T, rule *PostableRule) { + if rule.RuleCondition.Thresholds == nil { + t.Fatal("Expected Thresholds to be populated") + } + + specs, ok := rule.RuleCondition.Thresholds.Spec.(BasicRuleThresholds) + if !ok { + t.Fatalf("Expected BasicRuleThresholds, got %T", rule.RuleCondition.Thresholds.Spec) + } + + if len(specs) != 1 { + t.Fatalf("Expected 1 threshold spec, got %d", len(specs)) + } + + spec := specs[0] + if spec.Name != "critical" { + t.Errorf("Expected threshold name 'critical' (overwritten), got '%s'", spec.Name) + } + + if *spec.TargetValue != 80.0 { + t.Errorf("Expected TargetValue 80.0 (overwritten), got %v", *spec.TargetValue) + } + if spec.TargetUnit != "%" { + t.Errorf("Expected TargetUnit '%%' (overwritten), got '%s'", spec.TargetUnit) + } + if spec.RuleUnit != "percent" { + t.Errorf("Expected RuleUnit 'percent' (overwritten), got '%s'", spec.RuleUnit) + } + + if rule.Evaluation == nil { + t.Fatal("Expected Evaluation to be populated") + } + if window, ok := rule.Evaluation.Spec.(RollingWindow); ok { + if window.EvalWindow != rule.EvalWindow { + t.Errorf("Expected Evaluation EvalWindow to be overwritten to %v, got %v", rule.EvalWindow, window.EvalWindow) + } + if window.Frequency != rule.Frequency { + t.Errorf("Expected Evaluation Frequency to be overwritten to %v, got %v", rule.Frequency, window.Frequency) + } + } else { + t.Errorf("Expected RollingWindow spec, got %T", rule.Evaluation.Spec) + } + }, + }, + { + name: "schema v2 - does not populate thresholds and evaluation", + initRule: PostableRule{}, + content: []byte(`{ + "alert": "V2Test", + "schemaVersion": "v2", + "condition": { + "compositeQuery": { + "queryType": "builder", + "builderQueries": { + "A": { + "aggregateAttribute": { + "key": "test_metric" + } + } + } + }, + "target": 100.0, + "matchType": "1", + "op": "1" + } + }`), + kind: RuleDataKindJson, + expectError: false, + validate: func(t *testing.T, rule *PostableRule) { + if rule.SchemaVersion != "v2" { + t.Errorf("Expected schemaVersion 'v2', got '%s'", rule.SchemaVersion) + } + + if rule.RuleCondition.Thresholds != nil { + t.Error("Expected Thresholds to be nil for v2") + } + if rule.Evaluation != nil { + t.Error("Expected Evaluation to be nil for v2") + } + + if rule.EvalWindow != Duration(5*time.Minute) { + t.Error("Expected default EvalWindow to be applied") + } + if rule.RuleType != RuleTypeThreshold { + t.Error("Expected RuleType to be auto-detected") + } + }, + }, + { + name: "default schema version - defaults to v1 behavior", + initRule: PostableRule{}, + content: []byte(`{ + "alert": "DefaultSchemaTest", + "condition": { + "compositeQuery": { + "queryType": "builder", + "builderQueries": { + "A": { + "aggregateAttribute": { + "key": "test_metric" + } + } + } + }, + "target": 75.0, + "matchType": "1", + "op": "1" + } + }`), + kind: RuleDataKindJson, + expectError: false, + validate: func(t *testing.T, rule *PostableRule) { + if rule.SchemaVersion != DefaultSchemaVersion { + t.Errorf("Expected default schemaVersion '%s', got '%s'", DefaultSchemaVersion, rule.SchemaVersion) + } + if rule.RuleCondition.Thresholds == nil { + t.Error("Expected Thresholds to be populated for default schema version") + } + if rule.Evaluation == nil { + t.Error("Expected Evaluation to be populated for default schema version") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rule := tt.initRule + err := json.Unmarshal(tt.content, &rule) + if tt.expectError && err == nil { + t.Errorf("Expected error but got none") + } + if !tt.expectError && err != nil { + t.Errorf("Unexpected error: %v", err) + } + if tt.validate != nil && err == nil { + tt.validate(t, &rule) + } + }) + } +} + func TestParseIntoRuleThresholdGeneration(t *testing.T) { content := []byte(`{ "alert": "TestThresholds", @@ -310,6 +642,7 @@ func TestParseIntoRuleThresholdGeneration(t *testing.T) { func TestParseIntoRuleMultipleThresholds(t *testing.T) { content := []byte(`{ + "schemaVersion": "v2", "alert": "MultiThresholdAlert", "ruleType": "threshold_rule", "condition": { diff --git a/pkg/types/ruletypes/threshold.go b/pkg/types/ruletypes/threshold.go index 4c47c790dc40..fba9765d5793 100644 --- a/pkg/types/ruletypes/threshold.go +++ b/pkg/types/ruletypes/threshold.go @@ -2,13 +2,14 @@ package ruletypes import ( "encoding/json" - "github.com/SigNoz/signoz/pkg/errors" - "github.com/SigNoz/signoz/pkg/query-service/converter" - "github.com/SigNoz/signoz/pkg/query-service/model/v3" - "github.com/SigNoz/signoz/pkg/query-service/utils/labels" - "github.com/SigNoz/signoz/pkg/valuer" "math" "sort" + + "github.com/SigNoz/signoz/pkg/errors" + "github.com/SigNoz/signoz/pkg/query-service/converter" + v3 "github.com/SigNoz/signoz/pkg/query-service/model/v3" + "github.com/SigNoz/signoz/pkg/query-service/utils/labels" + "github.com/SigNoz/signoz/pkg/valuer" ) type ThresholdKind struct {