merging caches + removing import cycle via type any

mzack 2024-03-13 02:27:15 +01:00
parent 4b43bd0c65
commit 4aff6d7189
36 changed files with 455 additions and 432 deletions

View File

@@ -15,6 +15,7 @@ import (
"github.com/projectdiscovery/nuclei/v3/internal/pdcp"
"github.com/projectdiscovery/nuclei/v3/pkg/installer"
+ "github.com/projectdiscovery/nuclei/v3/pkg/pparser"
uncoverlib "github.com/projectdiscovery/uncover"
pdcpauth "github.com/projectdiscovery/utils/auth/pdcp"
"github.com/projectdiscovery/utils/env"
@@ -87,6 +88,7 @@ type Runner struct {
pdcpUploadErrMsg string
//general purpose temporary directory
tmpDir string
+ parser pparser.Parser
}
const pprofServerAddress = "127.0.0.1:8086"
@@ -148,12 +150,18 @@ func New(options *types.Options) (*Runner, error) {
}
}
- if options.Validate {
- parsers.ShouldValidate = true
+ parser, err := templates.New()
+ if err != nil {
+ return nil, err
}
+ if options.Validate {
+ parser.ShouldValidate = true
+ }
// TODO: refactor to pass options reference globally without cycles
- parsers.NoStrictSyntax = options.NoStrictSyntax
+ parser.NoStrictSyntax = options.NoStrictSyntax
+ runner.parser = parser
yaml.StrictSyntax = !options.NoStrictSyntax
if options.Headless {
@@ -431,6 +439,7 @@ func (r *Runner) RunEnumeration() error {
ExcludeMatchers: excludematchers.New(r.options.ExcludeMatchers),
InputHelper: input.NewHelper(),
TemporaryDirectory: r.tmpDir,
+ Parser: r.parser,
}
if r.options.ShouldUseHostError() {
@@ -458,7 +467,7 @@ func (r *Runner) RunEnumeration() error {
if err := store.ValidateTemplates(); err != nil {
return err
}
- if stats.GetValue(parsers.SyntaxErrorStats) == 0 && stats.GetValue(parsers.SyntaxWarningStats) == 0 && stats.GetValue(parsers.RuntimeWarningsStats) == 0 {
+ if stats.GetValue(templates.SyntaxErrorStats) == 0 && stats.GetValue(templates.SyntaxWarningStats) == 0 && stats.GetValue(templates.RuntimeWarningsStats) == 0 {
gologger.Info().Msgf("All templates validated successfully\n")
} else {
return errors.New("encountered errors while performing template validation")
@@ -470,9 +479,6 @@ func (r *Runner) RunEnumeration() error {
disk.PrintDeprecatedPathsMsgIfApplicable(r.options.Silent)
templates.PrintDeprecatedProtocolNameMsgIfApplicable(r.options.Silent, r.options.Verbose)
- // purge global caches primarily used for loading templates
- config.DefaultConfig.PurgeGlobalCache()
// add the hosts from the metadata queries of loaded templates into input provider
if r.options.Uncover && len(r.options.UncoverQuery) == 0 {
uncoverOpts := &uncoverlib.Options{
@@ -596,18 +602,18 @@ func (r *Runner) executeTemplatesInput(store *loader.Store, engine *core.Engine)
// displayExecutionInfo displays misc info about the nuclei engine execution
func (r *Runner) displayExecutionInfo(store *loader.Store) {
// Display stats for any loaded templates' syntax warnings or errors
- stats.Display(parsers.SyntaxWarningStats)
+ stats.Display(templates.SyntaxWarningStats)
- stats.Display(parsers.SyntaxErrorStats)
+ stats.Display(templates.SyntaxErrorStats)
- stats.Display(parsers.RuntimeWarningsStats)
+ stats.Display(templates.RuntimeWarningsStats)
if r.options.Verbose {
// only print these stats in verbose mode
- stats.DisplayAsWarning(parsers.HeadlessFlagWarningStats)
+ stats.DisplayAsWarning(templates.HeadlessFlagWarningStats)
- stats.DisplayAsWarning(parsers.CodeFlagWarningStats)
+ stats.DisplayAsWarning(templates.CodeFlagWarningStats)
- stats.DisplayAsWarning(parsers.TemplatesExecutedStats)
+ stats.DisplayAsWarning(templates.TemplatesExecutedStats)
}
- stats.DisplayAsWarning(parsers.UnsignedCodeWarning)
+ stats.DisplayAsWarning(templates.UnsignedCodeWarning)
- stats.ForceDisplayWarning(parsers.SkippedUnsignedStats)
+ stats.ForceDisplayWarning(templates.SkippedUnsignedStats)
cfg := config.DefaultConfig
@@ -632,8 +638,8 @@ func (r *Runner) displayExecutionInfo(store *loader.Store) {
value := v.Load()
if k == templates.Unsigned && value > 0 {
// adjust skipped unsigned templates via code or -dut flag
- value = value - uint64(stats.GetValue(parsers.SkippedUnsignedStats))
+ value = value - uint64(stats.GetValue(templates.SkippedUnsignedStats))
- value = value - uint64(stats.GetValue(parsers.CodeFlagWarningStats))
+ value = value - uint64(stats.GetValue(templates.CodeFlagWarningStats))
}
if value > 0 {
if k != templates.Unsigned {
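Taken together, the runner hunks above replace the package-level switches in pkg/parsers with a parser instance owned by the Runner and passed down through the executor options. A minimal wiring sketch, using only names introduced in this diff; executorOpts is an illustrative local variable and the other option fields are elided:

// Sketch of the per-runner wiring; assumes templates.New() returns (*templates.Parser, error)
// as the calls above imply, instead of toggling globals in pkg/parsers.
parser, err := templates.New()
if err != nil {
	return nil, err
}
if options.Validate {
	parser.ShouldValidate = true // was the global parsers.ShouldValidate
}
parser.NoStrictSyntax = options.NoStrictSyntax
runner.parser = parser
// later, in RunEnumeration, the same instance reaches the engine:
executorOpts := protocols.ExecutorOptions{
	// ... other fields as in the hunk above ...
	Parser: runner.parser, // satisfies the new pparser.Parser interface
}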

View File

@@ -4,7 +4,7 @@ import (
"os"
"testing"
- "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
"github.com/projectdiscovery/nuclei/v3/pkg/model/types/severity"
"github.com/projectdiscovery/nuclei/v3/pkg/types"
@@ -15,15 +15,15 @@ func TestCreateReportingOptions(t *testing.T) {
options.ReportingConfig = "../../integration_tests/test-issue-tracker-config1.yaml"
resultOptions, err := createReportingOptions(&options)
- assert.Nil(t, err)
+ require.Nil(t, err)
- assert.Equal(t, resultOptions.AllowList.Severities, severity.Severities{severity.High, severity.Critical})
+ require.Equal(t, resultOptions.AllowList.Severities, severity.Severities{severity.High, severity.Critical})
- assert.Equal(t, resultOptions.DenyList.Severities, severity.Severities{severity.Low})
+ require.Equal(t, resultOptions.DenyList.Severities, severity.Severities{severity.Low})
options.ReportingConfig = "../../integration_tests/test-issue-tracker-config2.yaml"
resultOptions2, err := createReportingOptions(&options)
- assert.Nil(t, err)
+ require.Nil(t, err)
- assert.Equal(t, resultOptions2.AllowList.Severities, resultOptions.AllowList.Severities)
+ require.Equal(t, resultOptions2.AllowList.Severities, resultOptions.AllowList.Severities)
- assert.Equal(t, resultOptions2.DenyList.Severities, resultOptions.DenyList.Severities)
+ require.Equal(t, resultOptions2.DenyList.Severities, resultOptions.DenyList.Severities)
}
type TestStruct1 struct {
@@ -69,8 +69,8 @@ func TestWalkReflectStructAssignsEnvVars(t *testing.T) {
Walk(testStruct, expandEndVars)
- assert.Equal(t, "value", testStruct.A)
+ require.Equal(t, "value", testStruct.A)
- assert.Equal(t, "value2", testStruct.Struct.B)
+ require.Equal(t, "value2", testStruct.Struct.B)
}
func TestWalkReflectStructHandlesDifferentTypes(t *testing.T) {
@@ -85,9 +85,9 @@ func TestWalkReflectStructHandlesDifferentTypes(t *testing.T) {
Walk(testStruct, expandEndVars)
- assert.Equal(t, "value", testStruct.A)
+ require.Equal(t, "value", testStruct.A)
- assert.Equal(t, "2", testStruct.B)
+ require.Equal(t, "2", testStruct.B)
- assert.Equal(t, "true", testStruct.C)
+ require.Equal(t, "true", testStruct.C)
}
func TestWalkReflectStructEmpty(t *testing.T) {
@@ -102,9 +102,9 @@ func TestWalkReflectStructEmpty(t *testing.T) {
Walk(testStruct, expandEndVars)
- assert.Equal(t, "value", testStruct.A)
+ require.Equal(t, "value", testStruct.A)
- assert.Equal(t, "", testStruct.B)
+ require.Equal(t, "", testStruct.B)
- assert.Equal(t, "true", testStruct.C)
+ require.Equal(t, "true", testStruct.C)
}
func TestWalkReflectStructWithNoYamlTag(t *testing.T) {
@@ -119,9 +119,9 @@ func TestWalkReflectStructWithNoYamlTag(t *testing.T) {
os.Setenv("GITHUB_USER", "testuser")
Walk(test, expandEndVars)
- assert.Equal(t, "testuser", test.A)
+ require.Equal(t, "testuser", test.A)
- assert.Equal(t, "testuser", test.B.B, test.B)
+ require.Equal(t, "testuser", test.B.B, test.B)
- assert.Equal(t, "$GITHUB_USER", test.C)
+ require.Equal(t, "$GITHUB_USER", test.C)
}
func TestWalkReflectStructHandlesNestedStructs(t *testing.T) {
@@ -138,7 +138,7 @@ func TestWalkReflectStructHandlesNestedStructs(t *testing.T) {
Walk(testStruct, expandEndVars)
- assert.Equal(t, "value", testStruct.A)
+ require.Equal(t, "value", testStruct.A)
- assert.Equal(t, "2", testStruct.Struct.B)
+ require.Equal(t, "2", testStruct.Struct.B)
- assert.Equal(t, "true", testStruct.Struct.C)
+ require.Equal(t, "true", testStruct.Struct.C)
}

View File

@@ -11,18 +11,21 @@ import (
"github.com/projectdiscovery/nuclei/v3/pkg/catalog/loader"
"github.com/projectdiscovery/gologger"
- "github.com/projectdiscovery/nuclei/v3/pkg/parsers"
"github.com/projectdiscovery/nuclei/v3/pkg/templates"
"github.com/projectdiscovery/nuclei/v3/pkg/types"
)
// log available templates for verbose (-vv)
func (r *Runner) logAvailableTemplate(tplPath string) {
- t, err := parsers.ParseTemplate(tplPath, r.catalog)
+ t, err := r.parser.ParseTemplate(tplPath, r.catalog)
+ tpl, ok := t.(*templates.Template)
+ if !ok {
+ panic("not a template")
+ }
if err != nil {
gologger.Error().Msgf("Could not parse file '%s': %s\n", tplPath, err)
} else {
- r.verboseTemplate(t)
+ r.verboseTemplate(tpl)
}
}

View File

@@ -5,7 +5,6 @@ import (
"time"
"github.com/logrusorgru/aurora"
- "github.com/projectdiscovery/nuclei/v3/pkg/catalog/config"
"github.com/projectdiscovery/nuclei/v3/pkg/catalog/loader"
"github.com/projectdiscovery/nuclei/v3/pkg/core"
"github.com/projectdiscovery/nuclei/v3/pkg/core/inputs"
@@ -41,6 +40,7 @@ func createEphemeralObjects(base *NucleiEngine, opts *types.Options) (*unsafeOpt
HostErrorsCache: base.hostErrCache,
Colorizer: aurora.NewAurora(true),
ResumeCfg: types.NewResumeCfg(),
+ Parser: base.parser,
}
if opts.RateLimitMinute > 0 {
u.executerOpts.RateLimiter = ratelimit.New(context.Background(), uint(opts.RateLimitMinute), time.Minute)
@@ -89,7 +89,6 @@ func (e *ThreadSafeNucleiEngine) GlobalLoadAllTemplates() error {
// GlobalResultCallback sets a callback function which will be called for each result
func (e *ThreadSafeNucleiEngine) GlobalResultCallback(callback func(event *output.ResultEvent)) {
e.eng.resultCallbacks = []func(*output.ResultEvent){callback}
- config.DefaultConfig.PurgeGlobalCache()
}
// ExecuteWithCallback executes templates on targets and calls callback on each result(only if results are found)

View File

@@ -71,6 +71,7 @@ type NucleiEngine struct {
mode engineMode
browserInstance *engine.Browser
httpClient *retryablehttp.Client
+ parser *templates.Parser
// unexported meta options
opts *types.Options

View File

@@ -27,6 +27,7 @@ import (
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/http/httpclientpool"
"github.com/projectdiscovery/nuclei/v3/pkg/reporting"
+ "github.com/projectdiscovery/nuclei/v3/pkg/templates"
"github.com/projectdiscovery/nuclei/v3/pkg/testutils"
"github.com/projectdiscovery/nuclei/v3/pkg/types"
"github.com/projectdiscovery/ratelimit"
@@ -113,6 +114,12 @@ func (e *NucleiEngine) init() error {
e.httpClient = httpclient
}
+ if parser, err := templates.New(); err != nil {
+ return err
+ } else {
+ e.parser = parser
+ }
_ = protocolstate.Init(e.opts)
_ = protocolinit.Init(e.opts)
e.applyRequiredDefaults()
@@ -157,6 +164,7 @@ func (e *NucleiEngine) init() error {
Colorizer: aurora.NewAurora(true),
ResumeCfg: types.NewResumeCfg(),
Browser: e.browserInstance,
+ Parser: e.parser,
}
if e.opts.RateLimitMinute > 0 {

BIN
memogen

Binary file not shown.

View File

@@ -45,9 +45,6 @@ type Config struct {
LatestNucleiTemplatesVersion string `json:"nuclei-templates-latest-version"`
LatestNucleiIgnoreHash string `json:"nuclei-latest-ignore-hash,omitempty"`
- // Other AppLevel/Global Settings
- registerdCaches []GlobalCache `json:"-"` // registered global caches
// internal / unexported fields
disableUpdates bool `json:"-"` // disable updates both version check and template updates
homeDir string `json:"-"` // User Home Directory
@@ -301,19 +298,6 @@ func (c *Config) WriteTemplatesIndex(index map[string]string) error {
return os.WriteFile(indexFile, buff.Bytes(), 0600)
}
- // RegisterGlobalCache registers a global cache at app level
- // and is available to be purged on demand
- func (c *Config) RegisterGlobalCache(cache GlobalCache) {
- c.registerdCaches = append(c.registerdCaches, cache)
- }
- // PurgeGlobalCache purges all registered global caches
- func (c *Config) PurgeGlobalCache() {
- for _, cache := range c.registerdCaches {
- cache.Purge()
- }
- }
// getTemplatesConfigFilePath returns configDir/.templates-config.json file path
func (c *Config) getTemplatesConfigFilePath() string {
return filepath.Join(c.configDir, TemplateConfigFileName)

View File

@@ -13,13 +13,6 @@ import (
stringsutil "github.com/projectdiscovery/utils/strings"
)
- // GlobalCache are global cache that have global
- // scope and are not purged but can be purged
- // via config.DefaultConfig
- type GlobalCache interface {
- Purge()
- }
var knownConfigFiles = []string{"cves.json", "contributors.json", "TEMPLATES-STATS.json"}
// TemplateFormat

View File

@@ -12,11 +12,9 @@ import (
"github.com/pkg/errors"
"github.com/projectdiscovery/gologger"
"github.com/projectdiscovery/nuclei/v3/pkg/catalog"
- "github.com/projectdiscovery/nuclei/v3/pkg/catalog/config"
cfg "github.com/projectdiscovery/nuclei/v3/pkg/catalog/config"
"github.com/projectdiscovery/nuclei/v3/pkg/catalog/loader/filter"
"github.com/projectdiscovery/nuclei/v3/pkg/model/types/severity"
- "github.com/projectdiscovery/nuclei/v3/pkg/parsers"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols"
"github.com/projectdiscovery/nuclei/v3/pkg/templates"
templateTypes "github.com/projectdiscovery/nuclei/v3/pkg/templates/types"
@@ -66,7 +64,7 @@ type Config struct {
// Store is a storage for loaded nuclei templates
type Store struct {
- tagFilter *filter.TagFilter
+ tagFilter *templates.TagFilter
pathFilter *filter.PathFilter
config *Config
finalTemplates []string
@@ -112,7 +110,7 @@ func NewConfig(options *types.Options, catalog catalog.Catalog, executerOpts pro
// New creates a new template store based on provided configuration
func New(config *Config) (*Store, error) {
- tagFilter, err := filter.New(&filter.Config{
+ tagFilter, err := templates.NewTagFilter(&templates.Config{
Tags: config.Tags,
ExcludeTags: config.ExcludeTags,
Authors: config.Authors,
@@ -268,25 +266,27 @@ func (store *Store) ValidateTemplates() error {
filteredTemplatePaths := store.pathFilter.Match(templatePaths)
filteredWorkflowPaths := store.pathFilter.Match(workflowPaths)
- if areTemplatesValid(store, filteredTemplatePaths) && areWorkflowsValid(store, filteredWorkflowPaths) {
+ if store.areTemplatesValid(filteredTemplatePaths) && store.areWorkflowsValid(filteredWorkflowPaths) {
return nil
}
return errors.New("errors occurred during template validation")
}
- func areWorkflowsValid(store *Store, filteredWorkflowPaths map[string]struct{}) bool {
+ func (store *Store) areWorkflowsValid(filteredWorkflowPaths map[string]struct{}) bool {
- return areWorkflowOrTemplatesValid(store, filteredWorkflowPaths, true, func(templatePath string, tagFilter *filter.TagFilter) (bool, error) {
+ return store.areWorkflowOrTemplatesValid(filteredWorkflowPaths, true, func(templatePath string, tagFilter *templates.TagFilter) (bool, error) {
- return parsers.LoadWorkflow(templatePath, store.config.Catalog)
+ return false, nil
+ // return store.config.ExecutorOptions.Parser.LoadWorkflow(templatePath, store.config.Catalog)
})
}
- func areTemplatesValid(store *Store, filteredTemplatePaths map[string]struct{}) bool {
+ func (store *Store) areTemplatesValid(filteredTemplatePaths map[string]struct{}) bool {
- return areWorkflowOrTemplatesValid(store, filteredTemplatePaths, false, func(templatePath string, tagFilter *filter.TagFilter) (bool, error) {
+ return store.areWorkflowOrTemplatesValid(filteredTemplatePaths, false, func(templatePath string, tagFilter *templates.TagFilter) (bool, error) {
- return parsers.LoadTemplate(templatePath, store.tagFilter, nil, store.config.Catalog)
+ return false, nil
+ // return store.config.ExecutorOptions.Parser.LoadTemplate(templatePath, store.tagFilter, nil, store.config.Catalog)
})
}
- func areWorkflowOrTemplatesValid(store *Store, filteredTemplatePaths map[string]struct{}, isWorkflow bool, load func(templatePath string, tagFilter *filter.TagFilter) (bool, error)) bool {
+ func (store *Store) areWorkflowOrTemplatesValid(filteredTemplatePaths map[string]struct{}, isWorkflow bool, load func(templatePath string, tagFilter *templates.TagFilter) (bool, error)) bool {
areTemplatesValid := true
for templatePath := range filteredTemplatePaths {
@@ -339,7 +339,7 @@ func areWorkflowTemplatesValid(store *Store, workflows []*workflows.WorkflowTemp
}
func isParsingError(message string, template string, err error) bool {
- if errors.Is(err, filter.ErrExcluded) {
+ if errors.Is(err, templates.ErrExcluded) {
return false
}
if errors.Is(err, templates.ErrCreateTemplateExecutor) {
@@ -362,7 +362,7 @@ func (store *Store) LoadWorkflows(workflowsList []string) []*templates.Template
loadedWorkflows := make([]*templates.Template, 0, len(workflowPathMap))
for workflowPath := range workflowPathMap {
- loaded, err := parsers.LoadWorkflow(workflowPath, store.config.Catalog)
+ loaded, err := store.config.ExecutorOptions.Parser.LoadWorkflow(workflowPath, store.config.Catalog)
if err != nil {
gologger.Warning().Msgf("Could not load workflow %s: %s\n", workflowPath, err)
}
@@ -387,38 +387,38 @@ func (store *Store) LoadTemplatesWithTags(templatesList, tags []string) []*templ
loadedTemplates := make([]*templates.Template, 0, len(templatePathMap))
for templatePath := range templatePathMap {
- loaded, err := parsers.LoadTemplate(templatePath, store.tagFilter, tags, store.config.Catalog)
+ loaded, err := store.config.ExecutorOptions.Parser.LoadTemplate(templatePath, store.tagFilter, tags, store.config.Catalog)
if loaded || store.pathFilter.MatchIncluded(templatePath) {
parsed, err := templates.Parse(templatePath, store.preprocessor, store.config.ExecutorOptions)
if err != nil {
// exclude templates not compatible with offline matching from total runtime warning stats
if !errors.Is(err, templates.ErrIncompatibleWithOfflineMatching) {
- stats.Increment(parsers.RuntimeWarningsStats)
+ stats.Increment(templates.RuntimeWarningsStats)
}
gologger.Warning().Msgf("Could not parse template %s: %s\n", templatePath, err)
} else if parsed != nil {
if !parsed.Verified && store.config.ExecutorOptions.Options.DisableUnsignedTemplates {
// skip unverified templates when prompted to
- stats.Increment(parsers.SkippedUnsignedStats)
+ stats.Increment(templates.SkippedUnsignedStats)
continue
}
if len(parsed.RequestsHeadless) > 0 && !store.config.ExecutorOptions.Options.Headless {
// donot include headless template in final list if headless flag is not set
- stats.Increment(parsers.HeadlessFlagWarningStats)
+ stats.Increment(templates.HeadlessFlagWarningStats)
if cfg.DefaultConfig.LogAllEvents {
gologger.Print().Msgf("[%v] Headless flag is required for headless template '%s'.\n", aurora.Yellow("WRN").String(), templatePath)
}
} else if len(parsed.RequestsCode) > 0 && !store.config.ExecutorOptions.Options.EnableCodeTemplates {
// donot include 'Code' protocol custom template in final list if code flag is not set
- stats.Increment(parsers.CodeFlagWarningStats)
+ stats.Increment(templates.CodeFlagWarningStats)
if cfg.DefaultConfig.LogAllEvents {
gologger.Print().Msgf("[%v] Code flag is required for code protocol template '%s'.\n", aurora.Yellow("WRN").String(), templatePath)
}
} else if len(parsed.RequestsCode) > 0 && !parsed.Verified && len(parsed.Workflows) == 0 {
// donot include unverified 'Code' protocol custom template in final list
- stats.Increment(parsers.UnsignedCodeWarning)
+ stats.Increment(templates.UnsignedCodeWarning)
// these will be skipped so increment skip counter
- stats.Increment(parsers.SkippedUnsignedStats)
+ stats.Increment(templates.SkippedUnsignedStats)
if cfg.DefaultConfig.LogAllEvents {
gologger.Print().Msgf("[%v] Tampered/Unsigned template at %v.\n", aurora.Yellow("WRN").String(), templatePath)
}
@@ -428,9 +428,9 @@ func (store *Store) LoadTemplatesWithTags(templatesList, tags []string) []*templ
}
}
if err != nil {
- if strings.Contains(err.Error(), filter.ErrExcluded.Error()) {
+ if strings.Contains(err.Error(), templates.ErrExcluded.Error()) {
- stats.Increment(parsers.TemplatesExecutedStats)
+ stats.Increment(templates.TemplatesExecutedStats)
- if config.DefaultConfig.LogAllEvents {
+ if cfg.DefaultConfig.LogAllEvents {
gologger.Print().Msgf("[%v] %v\n", aurora.Yellow("WRN").String(), err.Error())
}
continue

View File

@@ -7,7 +7,7 @@ import (
"github.com/projectdiscovery/nuclei/v3/pkg/model/types/severity"
"github.com/projectdiscovery/nuclei/v3/pkg/model/types/stringslice"
- "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
"gopkg.in/yaml.v2"
)
@@ -29,10 +29,10 @@ func TestInfoJsonMarshal(t *testing.T) {
}
result, err := json.Marshal(&info)
- assert.Nil(t, err)
+ require.Nil(t, err)
expected := `{"name":"Test Template Name","author":["forgedhallpass","ice3man"],"tags":["cve","misc"],"description":"Test description","reference":"Reference1","severity":"high","metadata":{"array_key":["array_value1","array_value2"],"map_key":{"key1":"val1"},"string_key":"string_value"}}`
- assert.Equal(t, expected, string(result))
+ require.Equal(t, expected, string(result))
}
func TestInfoYamlMarshal(t *testing.T) {
@@ -53,7 +53,7 @@ func TestInfoYamlMarshal(t *testing.T) {
}
result, err := yaml.Marshal(&info)
- assert.Nil(t, err)
+ require.Nil(t, err)
expected := `name: Test Template Name
author:
@@ -73,7 +73,7 @@ metadata:
key1: val1
string_key: string_value
`
- assert.Equal(t, expected, string(result))
+ require.Equal(t, expected, string(result))
}
func TestUnmarshal(t *testing.T) {
@@ -94,13 +94,13 @@ func TestUnmarshal(t *testing.T) {
t.Helper()
info := Info{}
err := yaml.Unmarshal([]byte(yamlPayload), &info)
- assert.Nil(t, err)
+ require.Nil(t, err)
- assert.Equal(t, info.Name, templateName)
+ require.Equal(t, info.Name, templateName)
- assert.Equal(t, info.Authors.ToSlice(), authors)
+ require.Equal(t, info.Authors.ToSlice(), authors)
- assert.Equal(t, info.Tags.ToSlice(), tags)
+ require.Equal(t, info.Tags.ToSlice(), tags)
- assert.Equal(t, info.SeverityHolder.Severity, severity.Critical)
+ require.Equal(t, info.SeverityHolder.Severity, severity.Critical)
- assert.Equal(t, info.Reference.ToSlice(), references)
+ require.Equal(t, info.Reference.ToSlice(), references)
- assert.Equal(t, info.Metadata, dynamicKeysMap)
+ require.Equal(t, info.Metadata, dynamicKeysMap)
return info
}
@@ -133,5 +133,5 @@ func TestUnmarshal(t *testing.T) {
info1 := assertUnmarshalledTemplateInfo(t, yamlPayload1)
info2 := assertUnmarshalledTemplateInfo(t, yamlPayload2)
- assert.Equal(t, info1, info2)
+ require.Equal(t, info1, info2)
}

View File

@@ -5,7 +5,6 @@ import (
"gopkg.in/yaml.v2"
- "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -17,8 +16,8 @@ func TestYamlMarshal(t *testing.T) {
severity := Holder{Severity: High}
marshalled, err := severity.MarshalYAML()
- assert.Nil(t, err, "could not marshal yaml")
+ require.Nil(t, err, "could not marshal yaml")
- assert.Equal(t, "high", marshalled, "could not marshal severity correctly")
+ require.Equal(t, "high", marshalled, "could not marshal severity correctly")
}
func TestYamlUnmarshalFail(t *testing.T) {
@@ -27,7 +26,7 @@ func TestYamlUnmarshalFail(t *testing.T) {
func TestGetSupportedSeverities(t *testing.T) {
severities := GetSupportedSeverities()
- assert.Equal(t, severities, Severities{Info, Low, Medium, High, Critical, Unknown})
+ require.Equal(t, severities, Severities{Info, Low, Medium, High, Critical, Unknown})
}
func testUnmarshal(t *testing.T, unmarshaller func(data []byte, v interface{}) error, payloadCreator func(value string) string) {
@@ -43,15 +42,15 @@ func testUnmarshal(t *testing.T, unmarshaller func(data []byte, v interface{}) e
for _, payload := range payloads { // nolint:scopelint // false-positive
t.Run(payload, func(t *testing.T) {
result := unmarshal(payload, unmarshaller)
- assert.Equal(t, result.Severity, Info)
+ require.Equal(t, result.Severity, Info)
- assert.Equal(t, result.Severity.String(), "info")
+ require.Equal(t, result.Severity.String(), "info")
})
}
}
func testUnmarshalFail(t *testing.T, unmarshaller func(data []byte, v interface{}) error, payloadCreator func(value string) string) {
t.Helper()
- assert.Panics(t, func() { unmarshal(payloadCreator("invalid"), unmarshaller) })
+ require.Panics(t, func() { unmarshal(payloadCreator("invalid"), unmarshaller) })
}
func unmarshal(value string, unmarshaller func(data []byte, v interface{}) error) Holder {

View File

@@ -5,7 +5,6 @@ import (
"testing"
"github.com/Knetic/govaluate"
- "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -41,7 +40,7 @@ func testDslExpressionScenarios(t *testing.T, dslExpressions map[string]interfac
actualResult := evaluateExpression(t, dslExpression)
if expectedResult != nil {
- assert.Equal(t, expectedResult, actualResult)
+ require.Equal(t, expectedResult, actualResult)
}
fmt.Printf("%s: \t %v\n", dslExpression, actualResult)

View File

@@ -1,198 +0,0 @@
package parsers
import (
"encoding/json"
"errors"
"fmt"
"regexp"
"github.com/projectdiscovery/nuclei/v3/pkg/catalog"
"github.com/projectdiscovery/nuclei/v3/pkg/catalog/config"
"github.com/projectdiscovery/nuclei/v3/pkg/catalog/loader/filter"
"github.com/projectdiscovery/nuclei/v3/pkg/templates"
"github.com/projectdiscovery/nuclei/v3/pkg/templates/cache"
"github.com/projectdiscovery/nuclei/v3/pkg/templates/types"
"github.com/projectdiscovery/nuclei/v3/pkg/utils"
"github.com/projectdiscovery/nuclei/v3/pkg/utils/stats"
errorutil "github.com/projectdiscovery/utils/errors"
"gopkg.in/yaml.v2"
)
var (
ErrMandatoryFieldMissingFmt = errorutil.NewWithFmt("mandatory '%s' field is missing")
ErrInvalidField = errorutil.NewWithFmt("invalid field format for '%s' (allowed format is %s)")
ErrWarningFieldMissing = errorutil.NewWithFmt("field '%s' is missing")
ErrCouldNotLoadTemplate = errorutil.NewWithFmt("Could not load template %s: %s")
ErrLoadedWithWarnings = errorutil.NewWithFmt("Loaded template %s: with syntax warning : %s")
)
// LoadTemplate returns true if the template is valid and matches the filtering criteria.
func LoadTemplate(templatePath string, tagFilter *filter.TagFilter, extraTags []string, catalog catalog.Catalog) (bool, error) {
template, templateParseError := ParseTemplate(templatePath, catalog)
if templateParseError != nil {
return false, ErrCouldNotLoadTemplate.Msgf(templatePath, templateParseError)
}
if len(template.Workflows) > 0 {
return false, nil
}
validationError := validateTemplateMandatoryFields(template)
if validationError != nil {
stats.Increment(SyntaxErrorStats)
return false, ErrCouldNotLoadTemplate.Msgf(templatePath, validationError)
}
ret, err := isTemplateInfoMetadataMatch(tagFilter, template, extraTags)
if err != nil {
return ret, ErrCouldNotLoadTemplate.Msgf(templatePath, err)
}
// if template loaded then check the template for optional fields to add warnings
if ret {
validationWarning := validateTemplateOptionalFields(template)
if validationWarning != nil {
stats.Increment(SyntaxWarningStats)
return ret, ErrCouldNotLoadTemplate.Msgf(templatePath, validationWarning)
}
}
return ret, nil
}
// LoadWorkflow returns true if the workflow is valid and matches the filtering criteria.
func LoadWorkflow(templatePath string, catalog catalog.Catalog) (bool, error) {
template, templateParseError := ParseTemplate(templatePath, catalog)
if templateParseError != nil {
return false, templateParseError
}
if len(template.Workflows) > 0 {
if validationError := validateTemplateMandatoryFields(template); validationError != nil {
stats.Increment(SyntaxErrorStats)
return false, validationError
}
return true, nil
}
return false, nil
}
func isTemplateInfoMetadataMatch(tagFilter *filter.TagFilter, template *templates.Template, extraTags []string) (bool, error) {
match, err := tagFilter.Match(template, extraTags)
if err == filter.ErrExcluded {
return false, filter.ErrExcluded
}
return match, err
}
// validateTemplateMandatoryFields validates the mandatory fields of a template
// return error from this function will cause hard fail and not proceed further
func validateTemplateMandatoryFields(template *templates.Template) error {
info := template.Info
var validateErrors []error
if utils.IsBlank(info.Name) {
validateErrors = append(validateErrors, ErrMandatoryFieldMissingFmt.Msgf("name"))
}
if info.Authors.IsEmpty() {
validateErrors = append(validateErrors, ErrMandatoryFieldMissingFmt.Msgf("author"))
}
if template.ID == "" {
validateErrors = append(validateErrors, ErrMandatoryFieldMissingFmt.Msgf("id"))
} else if !templateIDRegexp.MatchString(template.ID) {
validateErrors = append(validateErrors, ErrInvalidField.Msgf("id", templateIDRegexp.String()))
}
if len(validateErrors) > 0 {
return errors.Join(validateErrors...)
}
return nil
}
// validateTemplateOptionalFields validates the optional fields of a template
// return error from this function will throw a warning and proceed further
func validateTemplateOptionalFields(template *templates.Template) error {
info := template.Info
var warnings []error
if template.Type() != types.WorkflowProtocol && utils.IsBlank(info.SeverityHolder.Severity.String()) {
warnings = append(warnings, ErrWarningFieldMissing.Msgf("severity"))
}
if len(warnings) > 0 {
return errors.Join(warnings...)
}
return nil
}
var (
parsedTemplatesCache *cache.Templates
ShouldValidate bool
NoStrictSyntax bool
templateIDRegexp = regexp.MustCompile(`^([a-zA-Z0-9]+[-_])*[a-zA-Z0-9]+$`)
)
const (
SyntaxWarningStats = "syntax-warnings"
SyntaxErrorStats = "syntax-errors"
RuntimeWarningsStats = "runtime-warnings"
UnsignedCodeWarning = "unsigned-warnings"
HeadlessFlagWarningStats = "headless-flag-missing-warnings"
TemplatesExecutedStats = "templates-executed"
CodeFlagWarningStats = "code-flag-missing-warnings"
// Note: this is redefined in workflows.go to avoid circular dependency, so make sure to keep it in sync
SkippedUnsignedStats = "skipped-unsigned-stats" // tracks loading of unsigned templates
)
func init() {
parsedTemplatesCache = cache.New()
config.DefaultConfig.RegisterGlobalCache(parsedTemplatesCache)
stats.NewEntry(SyntaxWarningStats, "Found %d templates with syntax warning (use -validate flag for further examination)")
stats.NewEntry(SyntaxErrorStats, "Found %d templates with syntax error (use -validate flag for further examination)")
stats.NewEntry(RuntimeWarningsStats, "Found %d templates with runtime error (use -validate flag for further examination)")
stats.NewEntry(UnsignedCodeWarning, "Found %d unsigned or tampered code template (carefully examine before using it & use -sign flag to sign them)")
stats.NewEntry(HeadlessFlagWarningStats, "Excluded %d headless template[s] (disabled as default), use -headless option to run headless templates.")
stats.NewEntry(CodeFlagWarningStats, "Excluded %d code template[s] (disabled as default), use -code option to run code templates.")
stats.NewEntry(TemplatesExecutedStats, "Excluded %d template[s] with known weak matchers / tags excluded from default run using .nuclei-ignore")
stats.NewEntry(SkippedUnsignedStats, "Skipping %d unsigned template[s]")
}
// ParseTemplate parses a template and returns a *templates.Template structure
func ParseTemplate(templatePath string, catalog catalog.Catalog) (*templates.Template, error) {
if value, err := parsedTemplatesCache.Has(templatePath); value != nil {
return value.(*templates.Template), err
}
data, err := utils.ReadFromPathOrURL(templatePath, catalog)
if err != nil {
return nil, err
}
template := &templates.Template{}
switch config.GetTemplateFormatFromExt(templatePath) {
case config.JSON:
err = json.Unmarshal(data, template)
case config.YAML:
if NoStrictSyntax {
err = yaml.Unmarshal(data, template)
} else {
err = yaml.UnmarshalStrict(data, template)
}
default:
err = fmt.Errorf("failed to identify template format expected JSON or YAML but got %v", templatePath)
}
if err != nil {
return nil, err
}
parsedTemplatesCache.Store(templatePath, template, nil)
return template, nil
}
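The deleted file above owned the global template cache, the ShouldValidate/NoStrictSyntax flags, and the stats registrations. Their replacement lives in pkg/templates and is not shown in this section; from the call sites elsewhere in this diff (templates.New, parser.ShouldValidate, parser.NoStrictSyntax, the templates.*Stats constants) one plausible, hedged shape would be:

// Hypothetical reconstruction for orientation only; the actual pkg/templates
// definition may differ. Names are taken from call sites in this commit.
type Parser struct {
	ShouldValidate       bool   // was the global parsers.ShouldValidate
	NoStrictSyntax       bool   // was the global parsers.NoStrictSyntax
	parsedTemplatesCache *Cache // per-parser cache, replacing the purgeable global above
}

func New() (*Parser, error) {
	return &Parser{parsedTemplatesCache: NewCache()}, nil
}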

View File

@@ -6,17 +6,18 @@ import (
"github.com/projectdiscovery/nuclei/v3/pkg/catalog/loader/filter"
"github.com/projectdiscovery/nuclei/v3/pkg/model"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols"
+ "github.com/projectdiscovery/nuclei/v3/pkg/templates"
)
type workflowLoader struct {
pathFilter *filter.PathFilter
- tagFilter *filter.TagFilter
+ tagFilter *templates.TagFilter
options *protocols.ExecutorOptions
}
// NewLoader returns a new workflow loader structure
func NewLoader(options *protocols.ExecutorOptions) (model.WorkflowLoader, error) {
- tagFilter, err := filter.New(&filter.Config{
+ tagFilter, err := templates.NewTagFilter(&templates.Config{
Authors: options.Options.Authors,
Tags: options.Options.Tags,
ExcludeTags: options.Options.ExcludeTags,
@@ -50,7 +51,7 @@ func (w *workflowLoader) GetTemplatePathsByTags(templateTags []string) []string
loadedTemplates := make([]string, 0, len(templatePathMap))
for templatePath := range templatePathMap {
- loaded, _ := LoadTemplate(templatePath, w.tagFilter, templateTags, w.options.Catalog)
+ loaded, _ := w.options.Parser.LoadTemplate(templatePath, w.tagFilter, templateTags, w.options.Catalog)
if loaded {
loadedTemplates = append(loadedTemplates, templatePath)
}
@@ -67,7 +68,7 @@ func (w *workflowLoader) GetTemplatePaths(templatesList []string, noValidate boo
loadedTemplates := make([]string, 0, len(templatesPathMap))
for templatePath := range templatesPathMap {
- matched, err := LoadTemplate(templatePath, w.tagFilter, nil, w.options.Catalog)
+ matched, err := w.options.Parser.LoadTemplate(templatePath, w.tagFilter, nil, w.options.Catalog)
if err != nil && !matched {
gologger.Warning().Msg(err.Error())
} else if matched || noValidate {

pkg/pparser/pparser.go Normal file
View File

@@ -0,0 +1,11 @@
package pparser
import (
"github.com/projectdiscovery/nuclei/v3/pkg/catalog"
)
type Parser interface {
LoadTemplate(templatePath string, tagFilter any, extraTags []string, catalog catalog.Catalog) (bool, error)
ParseTemplate(templatePath string, catalog catalog.Catalog) (any, error)
LoadWorkflow(templatePath string, catalog catalog.Catalog) (bool, error)
}
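The interface takes the tag filter and returns the parsed template as any on purpose: pkg/protocols can then carry a Parser in ExecutorOptions without importing pkg/templates, which is the import cycle this commit removes. The cost is a type assertion at call sites, as internal/runner/templates.go does above. A small, illustrative call-site sketch; parseOne is a hypothetical helper, not part of this commit:

// Sketch only: recovering the concrete type from ParseTemplate's any return value.
func parseOne(opts *protocols.ExecutorOptions, templatePath string, catalog catalog.Catalog) (*templates.Template, error) {
	t, err := opts.Parser.ParseTemplate(templatePath, catalog)
	if err != nil {
		return nil, err
	}
	tpl, ok := t.(*templates.Template)
	if !ok {
		return nil, fmt.Errorf("unexpected type %T from ParseTemplate", t)
	}
	return tpl, nil
}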

View File

@@ -4,9 +4,8 @@ import (
"encoding/hex"
"testing"
- "github.com/stretchr/testify/assert"
"github.com/projectdiscovery/nuclei/v3/pkg/operators"
+ "github.com/stretchr/testify/require"
)
const input = "abcdefghijklmnabcdefghijklmnabcdefghijklmnabcdefghijklmnabcdefghijklmnabcdefghijklmnabcdefghijklmnabcdefghijklmnabcdefghijklmn"
@@ -24,8 +23,8 @@ func TestHexDumpHighlighting(t *testing.T) {
t.Run("Test highlighting when the snippet is wrapped", func(t *testing.T) {
result, err := toHighLightedHexDump(hex.Dump([]byte(input)), "defghij")
- assert.Nil(t, err)
+ require.Nil(t, err)
- assert.Equal(t, highlightedHexDumpResponse, result.String())
+ require.Equal(t, highlightedHexDumpResponse, result.String())
})
t.Run("Test highlight when the snippet contains separator character", func(t *testing.T) {
@@ -36,8 +35,8 @@ func TestHexDumpHighlighting(t *testing.T) {
"00000000 61 73 64 66 61 73 64 66 61 73 64 \x1b[32m61\x1b[0m \x1b[32m7c\x1b[0m \x1b[32m62\x1b[0m 61 73 |asdfasdfasd\x1b[32ma\x1b[0m\x1b[32m|\x1b[0m\x1b[32mb\x1b[0mas|\n" +
"00000010 64 66 61 64 73 64 66 73 7c |dfadsdfs||\n"
- assert.Nil(t, err)
+ require.Nil(t, err)
- assert.Equal(t, expected, result.String())
+ require.Equal(t, expected, result.String())
})
}
@@ -59,7 +58,7 @@ func TestHighlight(t *testing.T) {
t.Run("Test highlighting when the snippet is wrapped", func(t *testing.T) {
result := Highlight(&operatorResult, hex.Dump([]byte(input)), false, true)
- assert.Equal(t, multiSnippetHighlightHexDumpResponse, result)
+ require.Equal(t, multiSnippetHighlightHexDumpResponse, result)
})
t.Run("Test highlighting without hexdump", func(t *testing.T) {
@@ -75,17 +74,17 @@ func TestHighlight(t *testing.T) {
"a\x1b[0m\x1b[32mb\x1b[0mc\x1b[32md\x1b[0m\x1b[32me\x1b[0m\x1b[32mf\x1b[0m\x1b[32mg\x1b[0m\x1b[32mh\x1b[0m\x1b[32mi\x1b[0m\x1b[32mj\x1b[0mklmn\x1b[32m" +
"a\x1b[0m\x1b[32mb\x1b[0mc\x1b[32md\x1b[0m\x1b[32me\x1b[0m\x1b[32mf\x1b[0m\x1b[32mg\x1b[0m\x1b[32mh\x1b[0m\x1b[32mi\x1b[0m\x1b[32mj\x1b[0mklmn"
print(result)
- assert.Equal(t, expected, result)
+ require.Equal(t, expected, result)
})
t.Run("Test the response is not modified if noColor is true", func(t *testing.T) {
result := Highlight(&operatorResult, input, true, false)
- assert.Equal(t, input, result)
+ require.Equal(t, input, result)
})
t.Run("Test the response is not modified if noColor is true", func(t *testing.T) {
result := Highlight(&operatorResult, hex.Dump([]byte(input)), true, true)
- assert.Equal(t, hex.Dump([]byte(input)), result)
+ require.Equal(t, hex.Dump([]byte(input)), result)
})
}
@@ -107,5 +106,5 @@ start ValueToMatch-2.1 end
"start \x1b[32mV\x1b[0m\x1b[32ma\x1b[0m\x1b[32ml\x1b[0m\x1b[32mu\x1b[0m\x1b[32me\x1b[0m\x1b[32mT\x1b[0m\x1b[32mo\x1b[0m\x1b[32mM\x1b[0m\x1b[32ma\x1b[0m\x1b[32mt\x1b[0m\x1b[32mc\x1b[0m\x1b[32mh\x1b[0m\x1b[32m-\x1b[0m\x1b[32m1\x1b[0m\x1b[32m.\x1b[0m\x1b[32m2\x1b[0m\x1b[32m.\x1b[0m\x1b[32m3\x1b[0m end\n" +
"start \x1b[32mV\x1b[0m\x1b[32ma\x1b[0m\x1b[32ml\x1b[0m\x1b[32mu\x1b[0m\x1b[32me\x1b[0m\x1b[32mT\x1b[0m\x1b[32mo\x1b[0m\x1b[32mM\x1b[0m\x1b[32ma\x1b[0m\x1b[32mt\x1b[0m\x1b[32mc\x1b[0m\x1b[32mh\x1b[0m\x1b[32m-\x1b[0m\x1b[32m2\x1b[0m\x1b[32m.\x1b[0m\x1b[32m1\x1b[0m end \n"
result := Highlight(&operatorResult, input, false, false)
- assert.Equal(t, expected, result)
+ require.Equal(t, expected, result)
}

View File

@@ -4,7 +4,7 @@ import (
"net"
"testing"
- "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
)
func TestGetRandomIp(t *testing.T) {
@@ -110,15 +110,15 @@ func TestGetRandomIp(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
ip, err := GetRandomIPWithCidr(test.cidr...)
if test.valid {
- assert.NoError(t, err)
+ require.NoError(t, err)
anyInRange := false
for _, cidr := range test.cidr {
_, network, _ := net.ParseCIDR(cidr)
anyInRange = anyInRange || network.Contains(ip)
}
- assert.Truef(t, anyInRange, "the IP address returned %v is not in range of the provided CIDRs", ip)
+ require.Truef(t, anyInRange, "the IP address returned %v is not in range of the provided CIDRs", ip)
} else {
- assert.Error(t, err, test.errorMsg)
+ require.Error(t, err, test.errorMsg)
}
})
}

View File

@@ -3,7 +3,7 @@ package replacer
import (
"testing"
- "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
)
func TestReplacerReplace(t *testing.T) {
@@ -77,7 +77,7 @@ func TestReplacerReplace(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- assert.Equal(t, test.expected, Replace(test.template, test.values))
+ require.Equal(t, test.expected, Replace(test.template, test.values))
})
}
}
@@ -135,7 +135,7 @@ func TestReplacerReplaceOne(t *testing.T) {
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
- assert.Equal(t, test.expected, ReplaceOne(test.template, test.key, test.value))
+ require.Equal(t, test.expected, ReplaceOne(test.template, test.key, test.value))
})
}
}

View File

@@ -23,6 +23,7 @@ var (
"testcases/redis-pass-brute.yaml",
"testcases/ssh-server-fingerprint.yaml",
}
+ parser *templates.Parser
executerOpts protocols.ExecutorOptions
)
@@ -40,6 +41,7 @@ func setup() {
Browser: nil,
Catalog: disk.NewCatalog(config.DefaultConfig.TemplatesDirectory),
RateLimiter: ratelimit.New(context.Background(), uint(options.RateLimit), time.Second),
+ Parser: parser,
}
workflowLoader, err := parsers.NewLoader(&executerOpts)
if err != nil {

View File

@@ -18,6 +18,7 @@ import (
"github.com/projectdiscovery/nuclei/v3/pkg/operators/extractors"
"github.com/projectdiscovery/nuclei/v3/pkg/operators/matchers"
"github.com/projectdiscovery/nuclei/v3/pkg/output"
+ "github.com/projectdiscovery/nuclei/v3/pkg/pparser"
"github.com/projectdiscovery/nuclei/v3/pkg/progress"
"github.com/projectdiscovery/nuclei/v3/pkg/projectfile"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/contextargs"
@@ -118,6 +119,7 @@ type ExecutorOptions struct {
OverrideThreadsCount PayloadThreadSetterCallback
//TemporaryDirectory is the directory to store temporary files
TemporaryDirectory string
+ Parser pparser.Parser
}
// GetThreadsForPayloadRequests returns the number of threads to use as default for

View File

@@ -4,7 +4,7 @@ import (
"strings"
"testing"
- "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
)
func TestMarkDownHeaderCreation(t *testing.T) {
@@ -21,7 +21,7 @@ func TestMarkDownHeaderCreation(t *testing.T) {
for _, currentTestCase := range testCases {
t.Run(strings.Join(currentTestCase.headers, ","), func(t1 *testing.T) {
- assert.Equal(t1, CreateTableHeader(currentTestCase.headers...), currentTestCase.expectedValue)
+ require.Equal(t1, CreateTableHeader(currentTestCase.headers...), currentTestCase.expectedValue)
})
}
}
@@ -34,8 +34,8 @@ func TestCreateTemplateInfoTableTooManyColumns(t *testing.T) {
{"h", "i"},
})
- assert.NotNil(t, err)
+ require.NotNil(t, err)
- assert.Empty(t, table)
+ require.Empty(t, table)
}
func TestCreateTemplateInfoTable1Column(t *testing.T) {
@@ -48,8 +48,8 @@ func TestCreateTemplateInfoTable1Column(t *testing.T) {
| c |
`
- assert.Nil(t, err)
+ require.Nil(t, err)
- assert.Equal(t, expected, table)
+ require.Equal(t, expected, table)
}
func TestCreateTemplateInfoTable2Columns(t *testing.T) {
@@ -66,8 +66,8 @@ func TestCreateTemplateInfoTable2Columns(t *testing.T) {
| d | e |
`
- assert.Nil(t, err)
+ require.Nil(t, err)
- assert.Equal(t, expected, table)
+ require.Equal(t, expected, table)
}
func TestCreateTemplateInfoTable3Columns(t *testing.T) {
@@ -86,6 +86,6 @@ func TestCreateTemplateInfoTable3Columns(t *testing.T) {
| h | i | |
`
- assert.Nil(t, err)
+ require.Nil(t, err)
- assert.Equal(t, expected, table)
+ require.Equal(t, expected, table)
}

View File

@@ -4,12 +4,11 @@ import (
"strings"
"testing"
- "github.com/stretchr/testify/assert"
"github.com/projectdiscovery/nuclei/v3/pkg/model"
"github.com/projectdiscovery/nuclei/v3/pkg/model/types/severity"
"github.com/projectdiscovery/nuclei/v3/pkg/model/types/stringslice"
"github.com/projectdiscovery/nuclei/v3/pkg/reporting/exporters/markdown/util"
+ "github.com/stretchr/testify/require"
)
func TestToMarkdownTableString(t *testing.T) {
@@ -44,6 +43,6 @@ func TestToMarkdownTableString(t *testing.T) {
actualAttributeSlice := strings.Split(result, "\n")
dynamicAttributeIndex := len(actualAttributeSlice) - len(expectedDynamicAttributes)
- assert.Equal(t, strings.Split(expectedOrderedAttributes, "\n"), actualAttributeSlice[:dynamicAttributeIndex]) // the first part of the result is ordered
+ require.Equal(t, strings.Split(expectedOrderedAttributes, "\n"), actualAttributeSlice[:dynamicAttributeIndex]) // the first part of the result is ordered
- assert.ElementsMatch(t, expectedDynamicAttributes, actualAttributeSlice[dynamicAttributeIndex:]) // dynamic parameters are not ordered
+ require.ElementsMatch(t, expectedDynamicAttributes, actualAttributeSlice[dynamicAttributeIndex:]) // dynamic parameters are not ordered
}

View File

@ -1,21 +1,22 @@
package jira package jira
import ( import (
"github.com/stretchr/testify/assert"
"strings" "strings"
"testing" "testing"
"github.com/stretchr/testify/require"
) )
func TestLinkCreation(t *testing.T) { func TestLinkCreation(t *testing.T) {
jiraIntegration := &Integration{} jiraIntegration := &Integration{}
link := jiraIntegration.CreateLink("ProjectDiscovery", "https://projectdiscovery.io") link := jiraIntegration.CreateLink("ProjectDiscovery", "https://projectdiscovery.io")
assert.Equal(t, "[ProjectDiscovery|https://projectdiscovery.io]", link) require.Equal(t, "[ProjectDiscovery|https://projectdiscovery.io]", link)
} }
func TestHorizontalLineCreation(t *testing.T) { func TestHorizontalLineCreation(t *testing.T) {
jiraIntegration := &Integration{} jiraIntegration := &Integration{}
horizontalLine := jiraIntegration.CreateHorizontalLine() horizontalLine := jiraIntegration.CreateHorizontalLine()
assert.True(t, strings.Contains(horizontalLine, "----")) require.True(t, strings.Contains(horizontalLine, "----"))
} }
func TestTableCreation(t *testing.T) { func TestTableCreation(t *testing.T) {
@ -27,11 +28,11 @@ func TestTableCreation(t *testing.T) {
{"d", "e"}, {"d", "e"},
}) })
assert.Nil(t, err) require.Nil(t, err)
expected := `| key | value | expected := `| key | value |
| a | b | | a | b |
| c | | | c | |
| d | e | | d | e |
` `
assert.Equal(t, expected, table) require.Equal(t, expected, table)
} }

View File

@ -1,27 +1,27 @@
package cache package templates
import ( import (
mapsutil "github.com/projectdiscovery/utils/maps" mapsutil "github.com/projectdiscovery/utils/maps"
) )
// Templates is a cache for caching and storing templates for reuse. // Cache is used for caching and storing parsed templates for reuse.
type Templates struct { type Cache struct {
items *mapsutil.SyncLockMap[string, parsedTemplateErrHolder] items *mapsutil.SyncLockMap[string, parsedTemplateErrHolder]
} }
// New returns a new templates cache // NewCache returns a new templates cache
func New() *Templates { func NewCache() *Cache {
return &Templates{items: mapsutil.NewSyncLockMap[string, parsedTemplateErrHolder]()} return &Cache{items: mapsutil.NewSyncLockMap[string, parsedTemplateErrHolder]()}
} }
type parsedTemplateErrHolder struct { type parsedTemplateErrHolder struct {
template interface{} template *Template
err error err error
} }
// Has returns true if the cache has a template. The template // Has returns the cached template for a path, if any. The template
// is returned along with any errors if found. // is returned along with any stored parse error, if found.
func (t *Templates) Has(template string) (interface{}, error) { func (t *Cache) Has(template string) (*Template, error) {
value, ok := t.items.Get(template) value, ok := t.items.Get(template)
if !ok { if !ok {
return nil, nil return nil, nil
@ -30,11 +30,11 @@ func (t *Templates) Has(template string) (interface{}, error) {
} }
// Store stores a template with data and error // Store stores a template with data and error
func (t *Templates) Store(template string, data interface{}, err error) { func (t *Cache) Store(template string, data *Template, err error) {
_ = t.items.Set(template, parsedTemplateErrHolder{template: data, err: err}) _ = t.items.Set(template, parsedTemplateErrHolder{template: data, err: err})
} }
// Purge the cache // Purge the cache
func (t *Templates) Purge() { func (t *Cache) Purge() {
t.items.Clear() t.items.Clear()
} }
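The renamed Cache stores a parsed template together with the error it produced, so repeated loads of the same path replay the cached result instead of re-parsing. A small usage sketch (the template path is hypothetical):

package main

import (
	"errors"
	"fmt"

	"github.com/projectdiscovery/nuclei/v3/pkg/templates"
)

func main() {
	cache := templates.NewCache()

	// A miss returns nil, nil so callers can fall through to parsing.
	if tpl, err := cache.Has("dns/example.yaml"); tpl == nil && err == nil {
		fmt.Println("cache miss, parse from disk")
	}

	// Store keeps the template and its error side by side; a broken
	// template replays the same failure on every subsequent lookup.
	cache.Store("dns/example.yaml", &templates.Template{}, errors.New("example parse failure"))

	if _, err := cache.Has("dns/example.yaml"); err != nil {
		fmt.Println("cached error:", err)
	}

	cache.Purge()
}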

View File

@ -1,4 +1,4 @@
package cache package templates
import ( import (
"errors" "errors"
@ -8,14 +8,14 @@ import (
) )
func TestCache(t *testing.T) { func TestCache(t *testing.T) {
templates := New() templates := NewCache()
testErr := errors.New("test error") testErr := errors.New("test error")
data, err := templates.Has("test") data, err := templates.Has("test")
require.Nil(t, err, "invalid value for err") require.Nil(t, err, "invalid value for err")
require.Nil(t, data, "invalid value for data") require.Nil(t, data, "invalid value for data")
templates.Store("test", "data", testErr) templates.Store("test", &Template{}, testErr)
data, err = templates.Has("test") data, err = templates.Has("test")
require.Equal(t, testErr, err, "invalid value for err") require.Equal(t, testErr, err, "invalid value for err")
require.Equal(t, "data", data, "invalid value for data") require.Equal(t, "data", data, "invalid value for data")

View File

@ -18,7 +18,6 @@ import (
"github.com/projectdiscovery/nuclei/v3/pkg/operators" "github.com/projectdiscovery/nuclei/v3/pkg/operators"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols" "github.com/projectdiscovery/nuclei/v3/pkg/protocols"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/offlinehttp" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/offlinehttp"
"github.com/projectdiscovery/nuclei/v3/pkg/templates/cache"
"github.com/projectdiscovery/nuclei/v3/pkg/templates/signer" "github.com/projectdiscovery/nuclei/v3/pkg/templates/signer"
"github.com/projectdiscovery/nuclei/v3/pkg/tmplexec" "github.com/projectdiscovery/nuclei/v3/pkg/tmplexec"
"github.com/projectdiscovery/nuclei/v3/pkg/utils" "github.com/projectdiscovery/nuclei/v3/pkg/utils"
@ -30,7 +29,6 @@ import (
var ( var (
ErrCreateTemplateExecutor = errors.New("cannot create template executer") ErrCreateTemplateExecutor = errors.New("cannot create template executer")
ErrIncompatibleWithOfflineMatching = errors.New("template can't be used for offline matching") ErrIncompatibleWithOfflineMatching = errors.New("template can't be used for offline matching")
parsedTemplatesCache *cache.Templates
// track how many templates are verified and by which signer // track how many templates are verified and by which signer
SignatureStats = map[string]*atomic.Uint64{} SignatureStats = map[string]*atomic.Uint64{}
) )
@ -39,23 +37,16 @@ const (
Unsigned = "unsigned" Unsigned = "unsigned"
) )
func init() {
parsedTemplatesCache = cache.New()
for _, verifier := range signer.DefaultTemplateVerifiers {
SignatureStats[verifier.Identifier()] = &atomic.Uint64{}
}
SignatureStats[Unsigned] = &atomic.Uint64{}
config.DefaultConfig.RegisterGlobalCache(parsedTemplatesCache)
}
// Parse parses a yaml request template file // Parse parses a yaml request template file
// TODO make sure reading from the disk the template parsing happens once: see parsers.ParseTemplate vs templates.Parse // TODO: make sure template parsing from disk happens only once: see Parser.ParseTemplate vs templates.Parse
//
//nolint:gocritic // this cannot be passed by pointer
func Parse(filePath string, preprocessor Preprocessor, options protocols.ExecutorOptions) (*Template, error) { func Parse(filePath string, preprocessor Preprocessor, options protocols.ExecutorOptions) (*Template, error) {
parser, ok := options.Parser.(*Parser)
if !ok {
panic("not a parser")
}
if !options.DoNotCache { if !options.DoNotCache {
if value, err := parsedTemplatesCache.Has(filePath); value != nil { if value, err := parser.compiledTemplatesCache.Has(filePath); value != nil {
return value.(*Template), err return value, err
} }
} }
@ -90,7 +81,7 @@ func Parse(filePath string, preprocessor Preprocessor, options protocols.Executo
} }
template.Path = filePath template.Path = filePath
if !options.DoNotCache { if !options.DoNotCache {
parsedTemplatesCache.Store(filePath, template, err) parser.compiledTemplatesCache.Store(filePath, template, err)
} }
return template, nil return template, nil
} }

View File

@ -3,23 +3,23 @@ package templates
import ( import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require"
) )
func Test_appendAtSignToAuthors(t *testing.T) { func Test_appendAtSignToAuthors(t *testing.T) {
result := appendAtSignToAuthors([]string{"user1", "user2", "user3"}) result := appendAtSignToAuthors([]string{"user1", "user2", "user3"})
assert.Equal(t, result, "@user1,@user2,@user3") require.Equal(t, result, "@user1,@user2,@user3")
} }
func Test_appendAtSignToMissingAuthors(t *testing.T) { func Test_appendAtSignToMissingAuthors(t *testing.T) {
result := appendAtSignToAuthors([]string{}) result := appendAtSignToAuthors([]string{})
assert.Equal(t, result, "@none") require.Equal(t, result, "@none")
result = appendAtSignToAuthors(nil) result = appendAtSignToAuthors(nil)
assert.Equal(t, result, "@none") require.Equal(t, result, "@none")
} }
func Test_appendAtSignToOneAuthor(t *testing.T) { func Test_appendAtSignToOneAuthor(t *testing.T) {
result := appendAtSignToAuthors([]string{"user1"}) result := appendAtSignToAuthors([]string{"user1"})
assert.Equal(t, result, "@user1") require.Equal(t, result, "@user1")
} }

130
pkg/templates/parser.go Normal file
View File

@ -0,0 +1,130 @@
package templates
import (
"encoding/json"
"fmt"
"sync/atomic"
"github.com/projectdiscovery/nuclei/v3/pkg/catalog"
"github.com/projectdiscovery/nuclei/v3/pkg/catalog/config"
"github.com/projectdiscovery/nuclei/v3/pkg/templates/signer"
"github.com/projectdiscovery/nuclei/v3/pkg/utils"
"github.com/projectdiscovery/nuclei/v3/pkg/utils/stats"
"gopkg.in/yaml.v2"
)
type Parser struct {
ShouldValidate bool
NoStrictSyntax bool
parsedTemplatesCache *Cache
compiledTemplatesCache *Cache
}
func New() (*Parser, error) {
p := &Parser{
parsedTemplatesCache: NewCache(),
compiledTemplatesCache: NewCache(),
}
for _, verifier := range signer.DefaultTemplateVerifiers {
SignatureStats[verifier.Identifier()] = &atomic.Uint64{}
}
SignatureStats[Unsigned] = &atomic.Uint64{}
return p, nil
}
// LoadTemplate returns true if the template is valid and matches the filtering criteria.
func (p *Parser) LoadTemplate(templatePath string, t any, extraTags []string, catalog catalog.Catalog) (bool, error) {
tagFilter, ok := t.(*TagFilter)
if !ok {
panic("not a *TagFilter")
}
t, templateParseError := p.ParseTemplate(templatePath, catalog)
if templateParseError != nil {
return false, ErrCouldNotLoadTemplate.Msgf(templatePath, templateParseError)
}
template, ok := t.(*Template)
if !ok {
panic("not a template")
}
if len(template.Workflows) > 0 {
return false, nil
}
validationError := validateTemplateMandatoryFields(template)
if validationError != nil {
stats.Increment(SyntaxErrorStats)
return false, ErrCouldNotLoadTemplate.Msgf(templatePath, validationError)
}
ret, err := isTemplateInfoMetadataMatch(tagFilter, template, extraTags)
if err != nil {
return ret, ErrCouldNotLoadTemplate.Msgf(templatePath, err)
}
// if the template matched the filter, check its optional fields and surface warnings
if ret {
validationWarning := validateTemplateOptionalFields(template)
if validationWarning != nil {
stats.Increment(SyntaxWarningStats)
return ret, ErrCouldNotLoadTemplate.Msgf(templatePath, validationWarning)
}
}
return ret, nil
}
// ParseTemplate parses a template file and returns it as any (concretely a *Template) to avoid import cycles
func (p *Parser) ParseTemplate(templatePath string, catalog catalog.Catalog) (any, error) {
if value, err := p.parsedTemplatesCache.Has(templatePath); value != nil {
return value, err
}
data, err := utils.ReadFromPathOrURL(templatePath, catalog)
if err != nil {
return nil, err
}
template := &Template{}
switch config.GetTemplateFormatFromExt(templatePath) {
case config.JSON:
err = json.Unmarshal(data, template)
case config.YAML:
if p.NoStrictSyntax {
err = yaml.Unmarshal(data, template)
} else {
err = yaml.UnmarshalStrict(data, template)
}
default:
err = fmt.Errorf("failed to identify template format expected JSON or YAML but got %v", templatePath)
}
if err != nil {
return nil, err
}
p.parsedTemplatesCache.Store(templatePath, template, nil)
return template, nil
}
// LoadWorkflow returns true if the template is a valid workflow: it has workflow entries and the mandatory fields
func (p *Parser) LoadWorkflow(templatePath string, catalog catalog.Catalog) (bool, error) {
t, templateParseError := p.ParseTemplate(templatePath, catalog)
if templateParseError != nil {
return false, templateParseError
}
template, ok := t.(*Template)
if !ok {
panic("not a template")
}
if len(template.Workflows) > 0 {
if validationError := validateTemplateMandatoryFields(template); validationError != nil {
stats.Increment(SyntaxErrorStats)
return false, validationError
}
return true, nil
}
return false, nil
}
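The new Parser bundles the caches and the validation flags that previously lived as package-level state in pkg/parsers. LoadTemplate takes the tag filter as any (and asserts it to *TagFilter internally) to keep the import graph acyclic. A hedged sketch of how a loader might drive it (the template path and tags are illustrative):

package main

import (
	"fmt"

	"github.com/projectdiscovery/nuclei/v3/pkg/catalog/disk"
	"github.com/projectdiscovery/nuclei/v3/pkg/templates"
)

func main() {
	parser, err := templates.New()
	if err != nil {
		panic(err)
	}
	parser.ShouldValidate = true

	catalog := disk.NewCatalog("")
	tagFilter, err := templates.NewTagFilter(&templates.Config{Tags: []string{"cves"}})
	if err != nil {
		panic(err)
	}

	// The filter travels as `any` and is asserted back to *TagFilter inside LoadTemplate.
	ok, err := parser.LoadTemplate("cves/2021/CVE-2021-27330.yaml", tagFilter, nil, catalog)
	fmt.Println("loaded:", ok, "err:", err)
}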

View File

@ -0,0 +1,7 @@
package templates
import "regexp"
var (
ReTemplateID = regexp.MustCompile(`^([a-zA-Z0-9]+[-_])*[a-zA-Z0-9]+$`)
)
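ReTemplateID accepts alphanumeric segments joined by single dashes or underscores, with no leading or trailing separator. A quick check against a few sample IDs (the IDs are illustrative):

package main

import (
	"fmt"

	"github.com/projectdiscovery/nuclei/v3/pkg/templates"
)

func main() {
	for _, id := range []string{"CVE-2021-27330", "exposed_panel", "invalid id", "-leading-dash"} {
		fmt.Printf("%-20s valid=%v\n", id, templates.ReTemplateID.MatchString(id))
	}
	// Only the first two IDs match the pattern.
}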

View File

@ -0,0 +1,13 @@
package templates
import (
errorutil "github.com/projectdiscovery/utils/errors"
)
var (
ErrMandatoryFieldMissingFmt = errorutil.NewWithFmt("mandatory '%s' field is missing")
ErrInvalidField = errorutil.NewWithFmt("invalid field format for '%s' (allowed format is %s)")
ErrWarningFieldMissing = errorutil.NewWithFmt("field '%s' is missing")
ErrCouldNotLoadTemplate = errorutil.NewWithFmt("Could not load template %s: %s")
ErrLoadedWithWarnings = errorutil.NewWithFmt("Loaded template %s with syntax warning: %s")
)

View File

@ -0,0 +1,13 @@
package templates
const (
SyntaxWarningStats = "syntax-warnings"
SyntaxErrorStats = "syntax-errors"
RuntimeWarningsStats = "runtime-warnings"
UnsignedCodeWarning = "unsigned-warnings"
HeadlessFlagWarningStats = "headless-flag-missing-warnings"
TemplatesExecutedStats = "templates-executed"
CodeFlagWarningStats = "code-flag-missing-warnings"
// Note: this is redefined in workflows.go to avoid circular dependency, so make sure to keep it in sync
SkippedUnsignedStatsTODO = "skipped-unsigned-stats" // tracks loading of unsigned templates
)

View File

@ -1,4 +1,4 @@
package parsers package templates
import ( import (
"errors" "errors"
@ -6,31 +6,29 @@ import (
"testing" "testing"
"github.com/projectdiscovery/nuclei/v3/pkg/catalog/disk" "github.com/projectdiscovery/nuclei/v3/pkg/catalog/disk"
"github.com/projectdiscovery/nuclei/v3/pkg/catalog/loader/filter"
"github.com/projectdiscovery/nuclei/v3/pkg/model" "github.com/projectdiscovery/nuclei/v3/pkg/model"
"github.com/projectdiscovery/nuclei/v3/pkg/model/types/severity" "github.com/projectdiscovery/nuclei/v3/pkg/model/types/severity"
"github.com/projectdiscovery/nuclei/v3/pkg/model/types/stringslice" "github.com/projectdiscovery/nuclei/v3/pkg/model/types/stringslice"
"github.com/projectdiscovery/nuclei/v3/pkg/templates"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestLoadTemplate(t *testing.T) { func TestLoadTemplate(t *testing.T) {
catalog := disk.NewCatalog("") catalog := disk.NewCatalog("")
origTemplatesCache := parsedTemplatesCache p, err := New()
defer func() { parsedTemplatesCache = origTemplatesCache }() require.Nil(t, err)
tt := []struct { tt := []struct {
name string name string
template *templates.Template template *Template
templateErr error templateErr error
filter filter.Config filter Config
expectedErr error expectedErr error
isValid bool isValid bool
}{ }{
{ {
name: "valid", name: "valid",
template: &templates.Template{ template: &Template{
ID: "CVE-2021-27330", ID: "CVE-2021-27330",
Info: model.Info{ Info: model.Info{
Name: "Valid template", Name: "Valid template",
@ -42,24 +40,24 @@ func TestLoadTemplate(t *testing.T) {
}, },
{ {
name: "emptyTemplate", name: "emptyTemplate",
template: &templates.Template{}, template: &Template{},
isValid: false, isValid: false,
expectedErr: errors.New("mandatory 'name' field is missing, mandatory 'author' field is missing, mandatory 'id' field is missing"), expectedErr: errors.New("mandatory 'name' field is missing\nmandatory 'author' field is missing\nmandatory 'id' field is missing"),
}, },
{ {
name: "emptyNameWithInvalidID", name: "emptyNameWithInvalidID",
template: &templates.Template{ template: &Template{
ID: "invalid id", ID: "invalid id",
Info: model.Info{ Info: model.Info{
Authors: stringslice.StringSlice{Value: "Author"}, Authors: stringslice.StringSlice{Value: "Author"},
SeverityHolder: severity.Holder{Severity: severity.Medium}, SeverityHolder: severity.Holder{Severity: severity.Medium},
}, },
}, },
expectedErr: errors.New("mandatory 'name' field is missing, invalid field format for 'id' (allowed format is ^([a-zA-Z0-9]+[-_])*[a-zA-Z0-9]+$)"), expectedErr: errors.New("mandatory 'name' field is missing\ninvalid field format for 'id' (allowed format is ^([a-zA-Z0-9]+[-_])*[a-zA-Z0-9]+$)"),
}, },
{ {
name: "emptySeverity", name: "emptySeverity",
template: &templates.Template{ template: &Template{
ID: "CVE-2021-27330", ID: "CVE-2021-27330",
Info: model.Info{ Info: model.Info{
Name: "Valid template", Name: "Valid template",
@ -71,7 +69,7 @@ func TestLoadTemplate(t *testing.T) {
}, },
{ {
name: "template-without-severity-with-correct-filter-id", name: "template-without-severity-with-correct-filter-id",
template: &templates.Template{ template: &Template{
ID: "CVE-2021-27330", ID: "CVE-2021-27330",
Info: model.Info{ Info: model.Info{
Name: "Valid template", Name: "Valid template",
@ -81,11 +79,11 @@ func TestLoadTemplate(t *testing.T) {
// should be error because the template is loaded // should be error because the template is loaded
expectedErr: errors.New("field 'severity' is missing"), expectedErr: errors.New("field 'severity' is missing"),
isValid: true, isValid: true,
filter: filter.Config{IncludeIds: []string{"CVE-2021-27330"}}, filter: Config{IncludeIds: []string{"CVE-2021-27330"}},
}, },
{ {
name: "template-without-severity-with-diff-filter-id", name: "template-without-severity-with-diff-filter-id",
template: &templates.Template{ template: &Template{
ID: "CVE-2021-27330", ID: "CVE-2021-27330",
Info: model.Info{ Info: model.Info{
Name: "Valid template", Name: "Valid template",
@ -93,7 +91,7 @@ func TestLoadTemplate(t *testing.T) {
}, },
}, },
isValid: false, isValid: false,
filter: filter.Config{IncludeIds: []string{"another-id"}}, filter: Config{IncludeIds: []string{"another-id"}},
// no error because the template is not loaded // no error because the template is not loaded
expectedErr: nil, expectedErr: nil,
}, },
@ -101,11 +99,11 @@ func TestLoadTemplate(t *testing.T) {
for _, tc := range tt { for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
parsedTemplatesCache.Store(tc.name, tc.template, tc.templateErr) p.parsedTemplatesCache.Store(tc.name, tc.template, tc.templateErr)
tagFilter, err := filter.New(&tc.filter) tagFilter, err := NewTagFilter(&tc.filter)
require.Nil(t, err) require.Nil(t, err)
success, err := LoadTemplate(tc.name, tagFilter, nil, catalog) success, err := p.LoadTemplate(tc.name, tagFilter, nil, catalog)
if tc.expectedErr == nil { if tc.expectedErr == nil {
require.NoError(t, err) require.NoError(t, err)
} else { } else {
@ -135,7 +133,7 @@ func TestLoadTemplate(t *testing.T) {
for i, tc := range tt { for i, tc := range tt {
name := fmt.Sprintf("regexp%d", i) name := fmt.Sprintf("regexp%d", i)
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
template := &templates.Template{ template := &Template{
ID: tc.id, ID: tc.id,
Info: model.Info{ Info: model.Info{
Name: "Valid template", Name: "Valid template",
@ -143,11 +141,11 @@ func TestLoadTemplate(t *testing.T) {
SeverityHolder: severity.Holder{Severity: severity.Medium}, SeverityHolder: severity.Holder{Severity: severity.Medium},
}, },
} }
parsedTemplatesCache.Store(name, template, nil) p.parsedTemplatesCache.Store(name, template, nil)
tagFilter, err := filter.New(&filter.Config{}) tagFilter, err := NewTagFilter(&Config{})
require.Nil(t, err) require.Nil(t, err)
success, err := LoadTemplate(name, tagFilter, nil, catalog) success, err := p.LoadTemplate(name, tagFilter, nil, catalog)
if tc.success { if tc.success {
require.NoError(t, err) require.NoError(t, err)
require.True(t, success) require.True(t, success)

View File

@ -0,0 +1,64 @@
package templates
import (
"errors"
"github.com/projectdiscovery/nuclei/v3/pkg/templates/types"
"github.com/projectdiscovery/nuclei/v3/pkg/utils"
)
// validateTemplateMandatoryFields validates the mandatory fields of a template
// an error returned from this function causes a hard failure and stops further processing
func validateTemplateMandatoryFields(template *Template) error {
info := template.Info
var validateErrors []error
if utils.IsBlank(info.Name) {
validateErrors = append(validateErrors, ErrMandatoryFieldMissingFmt.Msgf("name"))
}
if info.Authors.IsEmpty() {
validateErrors = append(validateErrors, ErrMandatoryFieldMissingFmt.Msgf("author"))
}
if template.ID == "" {
validateErrors = append(validateErrors, ErrMandatoryFieldMissingFmt.Msgf("id"))
} else if !ReTemplateID.MatchString(template.ID) {
validateErrors = append(validateErrors, ErrInvalidField.Msgf("id", ReTemplateID.String()))
}
if len(validateErrors) > 0 {
return errors.Join(validateErrors...)
}
return nil
}
func isTemplateInfoMetadataMatch(tagFilter *TagFilter, template *Template, extraTags []string) (bool, error) {
match, err := tagFilter.Match(template, extraTags)
if err == ErrExcluded {
return false, ErrExcluded
}
return match, err
}
// validateTemplateOptionalFields validates the optional fields of a template
// an error returned from this function is reported as a warning and processing continues
func validateTemplateOptionalFields(template *Template) error {
info := template.Info
var warnings []error
if template.Type() != types.WorkflowProtocol && utils.IsBlank(info.SeverityHolder.Severity.String()) {
warnings = append(warnings, ErrWarningFieldMissing.Msgf("severity"))
}
if len(warnings) > 0 {
return errors.Join(warnings...)
}
return nil
}
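validateTemplateMandatoryFields aggregates failures with errors.Join, which separates the wrapped messages with newlines; that is why the expected error strings in the loader test above switched from comma-separated to newline-separated. A standard-library-only illustration:

package main

import (
	"errors"
	"fmt"
)

func main() {
	// errors.Join (Go 1.20+) joins the individual messages with "\n".
	err := errors.Join(
		errors.New("mandatory 'name' field is missing"),
		errors.New("mandatory 'author' field is missing"),
		errors.New("mandatory 'id' field is missing"),
	)
	fmt.Println(err)
}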

View File

@ -1,4 +1,4 @@
package filter package templates
import ( import (
"bufio" "bufio"
@ -14,7 +14,6 @@ import (
"github.com/projectdiscovery/nuclei/v3/pkg/operators/common/dsl" "github.com/projectdiscovery/nuclei/v3/pkg/operators/common/dsl"
"github.com/projectdiscovery/nuclei/v3/pkg/operators/extractors" "github.com/projectdiscovery/nuclei/v3/pkg/operators/extractors"
"github.com/projectdiscovery/nuclei/v3/pkg/operators/matchers" "github.com/projectdiscovery/nuclei/v3/pkg/operators/matchers"
"github.com/projectdiscovery/nuclei/v3/pkg/templates"
"github.com/projectdiscovery/nuclei/v3/pkg/templates/types" "github.com/projectdiscovery/nuclei/v3/pkg/templates/types"
sliceutil "github.com/projectdiscovery/utils/slice" sliceutil "github.com/projectdiscovery/utils/slice"
) )
@ -42,7 +41,7 @@ var ErrExcluded = errors.New("the template was excluded")
// unless it is explicitly specified by the user using the includeTags (matchAllows field). // unless it is explicitly specified by the user using the includeTags (matchAllows field).
// Matching rule: (tag1 OR tag2...) AND (author1 OR author2...) AND (severity1 OR severity2...) AND (extraTags1 OR extraTags2...) // Matching rule: (tag1 OR tag2...) AND (author1 OR author2...) AND (severity1 OR severity2...) AND (extraTags1 OR extraTags2...)
// Returns true if the template matches the filter criteria, false otherwise. // Returns true if the template matches the filter criteria, false otherwise.
func (tagFilter *TagFilter) Match(template *templates.Template, extraTags []string) (bool, error) { func (tagFilter *TagFilter) Match(template *Template, extraTags []string) (bool, error) {
templateTags := template.Info.Tags.ToSlice() templateTags := template.Info.Tags.ToSlice()
for _, templateTag := range templateTags { for _, templateTag := range templateTags {
_, blocked := tagFilter.block[templateTag] _, blocked := tagFilter.block[templateTag]
@ -193,7 +192,7 @@ func isIdMatch(tagFilter *TagFilter, templateId string) bool {
return included && !excluded return included && !excluded
} }
func tryCollectConditionsMatchinfo(template *templates.Template) map[string]interface{} { func tryCollectConditionsMatchinfo(template *Template) map[string]interface{} {
// attempts to unwrap fields to their basic types // attempts to unwrap fields to their basic types
// mapping must be manual because of various abstraction layers, custom marshaling and forceful validation // mapping must be manual because of various abstraction layers, custom marshaling and forceful validation
parameters := map[string]interface{}{ parameters := map[string]interface{}{
@ -319,7 +318,7 @@ func collectExtractorTypes(extractors []*extractors.Extractor) []string {
return extractorTypes return extractorTypes
} }
func isConditionMatch(tagFilter *TagFilter, template *templates.Template) bool { func isConditionMatch(tagFilter *TagFilter, template *Template) bool {
if len(tagFilter.includeConditions) == 0 { if len(tagFilter.includeConditions) == 0 {
return true return true
} }
@ -365,7 +364,7 @@ type Config struct {
// New returns a tag filter for nuclei tag based execution // NewTagFilter returns a tag filter for nuclei tag-based execution
// //
// It takes into account Tags, Severities, ExcludeSeverities, Authors, IncludeTags, ExcludeTags, Conditions. // It takes into account Tags, Severities, ExcludeSeverities, Authors, IncludeTags, ExcludeTags, Conditions.
func New(config *Config) (*TagFilter, error) { func NewTagFilter(config *Config) (*TagFilter, error) {
filter := &TagFilter{ filter := &TagFilter{
allowedTags: make(map[string]struct{}), allowedTags: make(map[string]struct{}),
authors: make(map[string]struct{}), authors: make(map[string]struct{}),
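With the tag filter folded into package templates, it can match *Template values directly instead of importing pkg/templates and risking a cycle. A hedged sketch of constructing a filter and matching a template (field names follow the tests in this commit; the values are illustrative):

package main

import (
	"fmt"

	"github.com/projectdiscovery/nuclei/v3/pkg/model"
	"github.com/projectdiscovery/nuclei/v3/pkg/model/types/severity"
	"github.com/projectdiscovery/nuclei/v3/pkg/model/types/stringslice"
	"github.com/projectdiscovery/nuclei/v3/pkg/templates"
)

func main() {
	filter, err := templates.NewTagFilter(&templates.Config{
		Tags:       []string{"cves"},
		Severities: severity.Severities{severity.High},
	})
	if err != nil {
		panic(err)
	}

	tpl := &templates.Template{
		ID: "cve-test",
		Info: model.Info{
			Tags:           stringslice.StringSlice{Value: "cves"},
			Authors:        stringslice.StringSlice{Value: "pdteam"},
			SeverityHolder: severity.Holder{Severity: severity.High},
		},
	}

	matched, err := filter.Match(tpl, nil)
	fmt.Println("matched:", matched, "err:", err)
}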

View File

@ -1,4 +1,4 @@
package filter package templates
import ( import (
"testing" "testing"
@ -8,14 +8,13 @@ import (
"github.com/projectdiscovery/nuclei/v3/pkg/model/types/stringslice" "github.com/projectdiscovery/nuclei/v3/pkg/model/types/stringslice"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/dns" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/dns"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/http" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/http"
"github.com/projectdiscovery/nuclei/v3/pkg/templates"
"github.com/projectdiscovery/nuclei/v3/pkg/templates/types" "github.com/projectdiscovery/nuclei/v3/pkg/templates/types"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestTagBasedFilter(t *testing.T) { func TestTagBasedFilter(t *testing.T) {
newDummyTemplate := func(id string, tags, authors []string, severityValue severity.Severity, protocolType types.ProtocolType) *templates.Template { newDummyTemplate := func(id string, tags, authors []string, severityValue severity.Severity, protocolType types.ProtocolType) *Template {
dummyTemplate := &templates.Template{} dummyTemplate := &Template{}
if id != "" { if id != "" {
dummyTemplate.ID = id dummyTemplate.ID = id
} }
@ -35,7 +34,7 @@ func TestTagBasedFilter(t *testing.T) {
return dummyTemplate return dummyTemplate
} }
filter, err := New(&Config{ filter, err := NewTagFilter(&Config{
Tags: []string{"cves", "2021", "jira"}, Tags: []string{"cves", "2021", "jira"},
}) })
require.Nil(t, err) require.Nil(t, err)
@ -62,7 +61,7 @@ func TestTagBasedFilter(t *testing.T) {
}) })
t.Run("not-match-excludes", func(t *testing.T) { t.Run("not-match-excludes", func(t *testing.T) {
filter, err := New(&Config{ filter, err := NewTagFilter(&Config{
ExcludeTags: []string{"dos"}, ExcludeTags: []string{"dos"},
}) })
require.Nil(t, err) require.Nil(t, err)
@ -72,7 +71,7 @@ func TestTagBasedFilter(t *testing.T) {
require.Equal(t, ErrExcluded, err, "could not get correct error") require.Equal(t, ErrExcluded, err, "could not get correct error")
}) })
t.Run("match-includes", func(t *testing.T) { t.Run("match-includes", func(t *testing.T) {
filter, err := New(&Config{ filter, err := NewTagFilter(&Config{
Tags: []string{"cves", "fuzz"}, Tags: []string{"cves", "fuzz"},
ExcludeTags: []string{"dos", "fuzz"}, ExcludeTags: []string{"dos", "fuzz"},
IncludeTags: []string{"fuzz"}, IncludeTags: []string{"fuzz"},
@ -84,7 +83,7 @@ func TestTagBasedFilter(t *testing.T) {
require.True(t, matched, "could not get correct match") require.True(t, matched, "could not get correct match")
}) })
t.Run("match-includes", func(t *testing.T) { t.Run("match-includes", func(t *testing.T) {
filter, err := New(&Config{ filter, err := NewTagFilter(&Config{
IncludeTags: []string{"fuzz"}, IncludeTags: []string{"fuzz"},
ExcludeTags: []string{"fuzz"}, ExcludeTags: []string{"fuzz"},
}) })
@ -95,7 +94,7 @@ func TestTagBasedFilter(t *testing.T) {
require.True(t, matched, "could not get correct match") require.True(t, matched, "could not get correct match")
}) })
t.Run("match-author", func(t *testing.T) { t.Run("match-author", func(t *testing.T) {
filter, err := New(&Config{ filter, err := NewTagFilter(&Config{
Authors: []string{"pdteam"}, Authors: []string{"pdteam"},
}) })
require.Nil(t, err) require.Nil(t, err)
@ -104,7 +103,7 @@ func TestTagBasedFilter(t *testing.T) {
require.True(t, matched, "could not get correct match") require.True(t, matched, "could not get correct match")
}) })
t.Run("match-severity", func(t *testing.T) { t.Run("match-severity", func(t *testing.T) {
filter, err := New(&Config{ filter, err := NewTagFilter(&Config{
Severities: severity.Severities{severity.High}, Severities: severity.Severities{severity.High},
}) })
require.Nil(t, err) require.Nil(t, err)
@ -113,7 +112,7 @@ func TestTagBasedFilter(t *testing.T) {
require.True(t, matched, "could not get correct match") require.True(t, matched, "could not get correct match")
}) })
t.Run("match-id", func(t *testing.T) { t.Run("match-id", func(t *testing.T) {
filter, err := New(&Config{ filter, err := NewTagFilter(&Config{
IncludeIds: []string{"cve-test"}, IncludeIds: []string{"cve-test"},
}) })
require.Nil(t, err) require.Nil(t, err)
@ -122,7 +121,7 @@ func TestTagBasedFilter(t *testing.T) {
require.True(t, matched, "could not get correct match") require.True(t, matched, "could not get correct match")
}) })
t.Run("match-exclude-severity", func(t *testing.T) { t.Run("match-exclude-severity", func(t *testing.T) {
filter, err := New(&Config{ filter, err := NewTagFilter(&Config{
ExcludeSeverities: severity.Severities{severity.Low}, ExcludeSeverities: severity.Severities{severity.Low},
}) })
require.Nil(t, err) require.Nil(t, err)
@ -134,7 +133,7 @@ func TestTagBasedFilter(t *testing.T) {
require.False(t, matched, "could not get correct match") require.False(t, matched, "could not get correct match")
}) })
t.Run("match-exclude-with-tags", func(t *testing.T) { t.Run("match-exclude-with-tags", func(t *testing.T) {
filter, err := New(&Config{ filter, err := NewTagFilter(&Config{
Tags: []string{"tag"}, Tags: []string{"tag"},
ExcludeTags: []string{"another"}, ExcludeTags: []string{"another"},
}) })
@ -144,7 +143,7 @@ func TestTagBasedFilter(t *testing.T) {
require.False(t, matched, "could not get correct match") require.False(t, matched, "could not get correct match")
}) })
t.Run("match-conditions", func(t *testing.T) { t.Run("match-conditions", func(t *testing.T) {
filter, err := New(&Config{ filter, err := NewTagFilter(&Config{
Authors: []string{"pdteam"}, Authors: []string{"pdteam"},
Tags: []string{"jira"}, Tags: []string{"jira"},
Severities: severity.Severities{severity.High}, Severities: severity.Severities{severity.High},
@ -165,7 +164,7 @@ func TestTagBasedFilter(t *testing.T) {
require.False(t, matched, "could not get correct match") require.False(t, matched, "could not get correct match")
}) })
t.Run("match-type", func(t *testing.T) { t.Run("match-type", func(t *testing.T) {
filter, err := New(&Config{ filter, err := NewTagFilter(&Config{
Protocols: []types.ProtocolType{types.HTTPProtocol}, Protocols: []types.ProtocolType{types.HTTPProtocol},
}) })
require.Nil(t, err) require.Nil(t, err)
@ -175,7 +174,7 @@ func TestTagBasedFilter(t *testing.T) {
require.True(t, matched, "could not get correct match") require.True(t, matched, "could not get correct match")
}) })
t.Run("match-exclude-id", func(t *testing.T) { t.Run("match-exclude-id", func(t *testing.T) {
filter, err := New(&Config{ filter, err := NewTagFilter(&Config{
ExcludeIds: []string{"cve-test"}, ExcludeIds: []string{"cve-test"},
}) })
require.Nil(t, err) require.Nil(t, err)
@ -187,7 +186,7 @@ func TestTagBasedFilter(t *testing.T) {
require.False(t, matched, "could not get correct match") require.False(t, matched, "could not get correct match")
}) })
t.Run("match-exclude-type", func(t *testing.T) { t.Run("match-exclude-type", func(t *testing.T) {
filter, err := New(&Config{ filter, err := NewTagFilter(&Config{
ExcludeProtocols: []types.ProtocolType{types.HTTPProtocol}, ExcludeProtocols: []types.ProtocolType{types.HTTPProtocol},
}) })
require.Nil(t, err) require.Nil(t, err)
@ -267,9 +266,9 @@ func TestTagBasedFilter(t *testing.T) {
}) })
} }
func testAdvancedFiltering(t *testing.T, includeConditions []string, template *templates.Template, shouldError, shouldMatch bool) { func testAdvancedFiltering(t *testing.T, includeConditions []string, template *Template, shouldError, shouldMatch bool) {
// basic properties // basic properties
advancedFilter, err := New(&Config{IncludeConditions: includeConditions}) advancedFilter, err := NewTagFilter(&Config{IncludeConditions: includeConditions})
if shouldError { if shouldError {
require.NotNil(t, err) require.NotNil(t, err)
return return