simplify execution logic

This commit is contained in:
Mzack9999 2023-03-17 17:31:28 +01:00
parent 6a0db1c234
commit 1b7585476e

View File

@ -9,6 +9,7 @@ import (
"github.com/logrusorgru/aurora" "github.com/logrusorgru/aurora"
"github.com/projectdiscovery/nuclei/v2/pkg/testutils" "github.com/projectdiscovery/nuclei/v2/pkg/testutils"
sliceutil "github.com/projectdiscovery/utils/slice"
) )
var ( var (
@ -62,7 +63,9 @@ func main() {
os.Exit(1) os.Exit(1)
} }
failedTestTemplatePaths := runTests(toMap(toSlice(customTests))) customTestsList := normalizeSplit(customTests)
failedTestTemplatePaths := runTests(customTestsList)
if len(failedTestTemplatePaths) > 0 { if len(failedTestTemplatePaths) > 0 {
if githubAction { if githubAction {
@ -87,8 +90,8 @@ func debugTests() {
} }
} }
func runTests(customTemplatePaths map[string]struct{}) map[string]struct{} { func runTests(customTemplatePaths []string) []string {
failedTestTemplatePaths := map[string]struct{}{} var failedTestTemplatePaths []string
for proto, testCases := range protocolTests { for proto, testCases := range protocolTests {
if len(customTemplatePaths) == 0 { if len(customTemplatePaths) == 0 {
@ -96,9 +99,9 @@ func runTests(customTemplatePaths map[string]struct{}) map[string]struct{} {
} }
for templatePath, testCase := range testCases { for templatePath, testCase := range testCases {
if len(customTemplatePaths) == 0 || contains(customTemplatePaths, templatePath) { if len(customTemplatePaths) == 0 || sliceutil.Contains(customTemplatePaths, templatePath) {
if failedTemplatePath, err := execute(testCase, templatePath); err != nil { if failedTemplatePath, err := execute(testCase, templatePath); err != nil {
failedTestTemplatePaths[failedTemplatePath] = struct{}{} failedTestTemplatePaths = append(failedTestTemplatePaths, failedTemplatePath)
} }
} }
} }
@ -124,25 +127,6 @@ func expectResultsCount(results []string, expectedNumber int) error {
return nil return nil
} }
func toSlice(value string) []string { func normalizeSplit(str string) []string {
if strings.TrimSpace(value) == "" { return strings.Split(strings.TrimSpace(str), ",")
return []string{}
}
return strings.Split(value, ",")
}
func toMap(slice []string) map[string]struct{} {
result := make(map[string]struct{}, len(slice))
for _, value := range slice {
if _, ok := result[value]; !ok {
result[value] = struct{}{}
}
}
return result
}
func contains(input map[string]struct{}, value string) bool {
_, ok := input[value]
return ok
} }