refactor: replace date, time, time_format, time_to_string DSL functions to date_time

This commit is contained in:
forgedhallpass 2022-06-08 17:43:52 +03:00
parent a10d58c6d2
commit ef20e0711b
2 changed files with 169 additions and 133 deletions

View File

@ -47,6 +47,7 @@ var invalidDslFunctionMessageTemplate = "%w. correct method signature %q"
var dslFunctions map[string]dslFunction var dslFunctions map[string]dslFunction
var functionSignaturePattern = regexp.MustCompile(`(\w+)\s*\((?:([\w\d,\s]+)\s+([.\w\d{}&*]+))?\)([\s.\w\d{}&*]+)?`)
var dateFormatRegex = regexp.MustCompile("%([A-Za-z])") var dateFormatRegex = regexp.MustCompile("%([A-Za-z])")
type dslFunction struct { type dslFunction struct {
@ -155,66 +156,29 @@ func init() {
_ = reader.Close() _ = reader.Close()
return string(data), nil return string(data), nil
}), }),
"date": makeDslFunction(1, func(args ...interface{}) (interface{}, error) { "date_time": makeDslWithOptionalArgsFunction(
timeFormat := types.ToString(args[0]) "(dateTimeFormat string, optionalUnixTime interface{}) string",
timeFormatFragment := dateFormatRegex.FindAllStringSubmatch(timeFormat, -1) func(arguments ...interface{}) (interface{}, error) {
dateTimeFormat := types.ToString(arguments[0])
dateTimeFormatFragment := dateFormatRegex.FindAllStringSubmatch(dateTimeFormat, -1)
for _, currentFragment := range timeFormatFragment { argumentsSize := len(arguments)
if len(currentFragment) < 2 { if argumentsSize < 1 && argumentsSize > 2 {
continue return nil, errors.New("invalid number of arguments")
} }
now := time.Now()
prefixedFormatFragment := currentFragment[0]
switch currentFragment[1] {
case "Y", "y":
timeFormat = formatDateTime(timeFormat, prefixedFormatFragment, now.Year())
case "M", "m":
timeFormat = formatDateTime(timeFormat, prefixedFormatFragment, int(now.Month()))
case "D", "d":
timeFormat = formatDateTime(timeFormat, prefixedFormatFragment, now.Day())
default:
return nil, fmt.Errorf("invalid date format string: %s", prefixedFormatFragment)
}
}
return timeFormat, nil
}),
"time": makeDslFunction(1, func(args ...interface{}) (interface{}, error) {
timeFormat := types.ToString(args[0])
timeFormatFragment := dateFormatRegex.FindAllStringSubmatch(timeFormat, -1)
for _, currentFragment := range timeFormatFragment { currentTime, err := getCurrentTimeFromUserInput(arguments)
if len(currentFragment) < 2 { if err != nil {
continue return nil, err
} }
now := time.Now()
prefixedFormatFragment := currentFragment[0] if len(dateTimeFormatFragment) > 0 {
switch currentFragment[1] { return doSimpleTimeFormat(dateTimeFormatFragment, currentTime, dateTimeFormat)
case "H", "h": } else {
timeFormat = formatDateTime(timeFormat, prefixedFormatFragment, now.Hour()) return currentTime.Format(dateTimeFormat), nil
case "M", "m":
timeFormat = formatDateTime(timeFormat, prefixedFormatFragment, now.Minute())
case "S", "s":
timeFormat = formatDateTime(timeFormat, prefixedFormatFragment, now.Second())
default:
return nil, fmt.Errorf("invalid time format string: %s", prefixedFormatFragment)
} }
} },
return timeFormat, nil ),
}),
"time_format": makeDslFunction(1, func(args ...interface{}) (interface{}, error) {
t := time.Now()
return t.Format(args[0].(string)), nil
}),
"time_to_string": makeDslFunction(1, func(args ...interface{}) (interface{}, error) {
if got, ok := args[0].(time.Time); ok {
return got.String(), nil
}
if got, ok := args[0].(float64); ok {
seconds, nanoseconds := math.Modf(got)
return time.Unix(int64(seconds), int64(nanoseconds)).String(), nil
}
return nil, fmt.Errorf("invalid time format: %T", args[0])
}),
"base64_py": makeDslFunction(1, func(args ...interface{}) (interface{}, error) { "base64_py": makeDslFunction(1, func(args ...interface{}) (interface{}, error) {
// python encodes to base64 with lines of 76 bytes terminated by new line "\n" // python encodes to base64 with lines of 76 bytes terminated by new line "\n"
stdBase64 := base64.StdEncoding.EncodeToString([]byte(types.ToString(args[0]))) stdBase64 := base64.StdEncoding.EncodeToString([]byte(types.ToString(args[0])))
@ -549,37 +513,6 @@ func init() {
} }
} }
func toHexEncodedHash(hashToUse hash.Hash, data string) (interface{}, error) {
if _, err := hashToUse.Write([]byte(data)); err != nil {
return nil, err
}
return hex.EncodeToString(hashToUse.Sum(nil)), nil
}
func formatDateTime(inputFormat string, matchValue string, timeFragment int) string {
return strings.ReplaceAll(inputFormat, matchValue, appendSingleDigitZero(strconv.Itoa(timeFragment)))
}
// appendSingleDigitZero appends zero at front if not exists already doing two digit padding
func appendSingleDigitZero(value string) string {
if len(value) == 1 && (!strings.HasPrefix(value, "0") || value == "0") {
builder := &strings.Builder{}
builder.WriteRune('0')
builder.WriteString(value)
newVal := builder.String()
return newVal
}
return value
}
func createSignaturePart(numberOfParameters int) string {
params := make([]string, 0, numberOfParameters)
for i := 1; i <= numberOfParameters; i++ {
params = append(params, "arg"+strconv.Itoa(i))
}
return fmt.Sprintf("(%s interface{}) interface{}", strings.Join(params, ", "))
}
func makeDslWithOptionalArgsFunction(signaturePart string, dslFunctionLogic govaluate.ExpressionFunction) func(functionName string) dslFunction { func makeDslWithOptionalArgsFunction(signaturePart string, dslFunctionLogic govaluate.ExpressionFunction) func(functionName string) dslFunction {
return func(functionName string) dslFunction { return func(functionName string) dslFunction {
return dslFunction{ return dslFunction{
@ -604,6 +537,14 @@ func makeDslFunction(numberOfParameters int, dslFunctionLogic govaluate.Expressi
} }
} }
func createSignaturePart(numberOfParameters int) string {
params := make([]string, 0, numberOfParameters)
for i := 1; i <= numberOfParameters; i++ {
params = append(params, "arg"+strconv.Itoa(i))
}
return fmt.Sprintf("(%s interface{}) interface{}", strings.Join(params, ", "))
}
// HelperFunctions returns the dsl helper functions // HelperFunctions returns the dsl helper functions
func HelperFunctions() map[string]govaluate.ExpressionFunction { func HelperFunctions() map[string]govaluate.ExpressionFunction {
helperFunctions := make(map[string]govaluate.ExpressionFunction, len(dslFunctions)) helperFunctions := make(map[string]govaluate.ExpressionFunction, len(dslFunctions))
@ -657,8 +598,6 @@ func getDslFunctionSignatures() []string {
return result return result
} }
var functionSignaturePattern = regexp.MustCompile(`(\w+)\s*\((?:([\w\d,\s]+)\s+([.\w\d{}&*]+))?\)([\s.\w\d{}&*]+)?`)
func colorizeDslFunctionSignatures() []string { func colorizeDslFunctionSignatures() []string {
signatures := getDslFunctionSignatures() signatures := getDslFunctionSignatures()
@ -724,3 +663,75 @@ func randSeq(base string, n int) string {
} }
return string(b) return string(b)
} }
func toHexEncodedHash(hashToUse hash.Hash, data string) (interface{}, error) {
if _, err := hashToUse.Write([]byte(data)); err != nil {
return nil, err
}
return hex.EncodeToString(hashToUse.Sum(nil)), nil
}
func doSimpleTimeFormat(dateTimeFormatFragment [][]string, currentTime time.Time, dateTimeFormat string) (interface{}, error) {
for _, currentFragment := range dateTimeFormatFragment {
if len(currentFragment) < 2 {
continue
}
prefixedFormatFragment := currentFragment[0]
switch currentFragment[1] {
case "Y", "y":
dateTimeFormat = formatDateTime(dateTimeFormat, prefixedFormatFragment, currentTime.Year())
case "M":
dateTimeFormat = formatDateTime(dateTimeFormat, prefixedFormatFragment, int(currentTime.Month()))
case "D", "d":
dateTimeFormat = formatDateTime(dateTimeFormat, prefixedFormatFragment, currentTime.Day())
case "H", "h":
dateTimeFormat = formatDateTime(dateTimeFormat, prefixedFormatFragment, currentTime.Hour())
case "m":
dateTimeFormat = formatDateTime(dateTimeFormat, prefixedFormatFragment, currentTime.Minute())
case "S", "s":
dateTimeFormat = formatDateTime(dateTimeFormat, prefixedFormatFragment, currentTime.Second())
default:
return nil, fmt.Errorf("invalid date time format string: %s", prefixedFormatFragment)
}
}
return dateTimeFormat, nil
}
func getCurrentTimeFromUserInput(arguments []interface{}) (time.Time, error) {
var currentTime time.Time
if len(arguments) == 2 {
switch inputUnixTime := arguments[1].(type) {
case time.Time:
currentTime = inputUnixTime
case string:
unixTime, err := strconv.ParseInt(inputUnixTime, 10, 64)
if err != nil {
return time.Time{}, errors.New("invalid argument type")
}
currentTime = time.Unix(unixTime, 0)
case int64, float64:
currentTime = time.Unix(int64(inputUnixTime.(float64)), 0)
default:
return time.Time{}, errors.New("invalid argument type")
}
} else {
currentTime = time.Now()
}
return currentTime, nil
}
func formatDateTime(inputFormat string, matchValue string, timeFragment int) string {
return strings.ReplaceAll(inputFormat, matchValue, appendSingleDigitZero(strconv.Itoa(timeFragment)))
}
// appendSingleDigitZero appends zero at front if not exists already doing two digit padding
func appendSingleDigitZero(value string) string {
if len(value) == 1 && (!strings.HasPrefix(value, "0") || value == "0") {
builder := &strings.Builder{}
builder.WriteRune('0')
builder.WriteString(value)
newVal := builder.String()
return newVal
}
return value
}

View File

@ -51,16 +51,44 @@ func TestDSLGzipSerialize(t *testing.T) {
require.Equal(t, "hello world", data.(string), "could not get gzip encoded data") require.Equal(t, "hello world", data.(string), "could not get gzip encoded data")
} }
func TestTimeToStringDSLFunction(t *testing.T) { func TestDateTimeDSLFunction(t *testing.T) {
compiled, err := govaluate.NewEvaluableExpressionWithFunctions("time_to_string(data)", HelperFunctions())
require.Nil(t, err, "could not compile encoder") testDateTimeFormat := func(t *testing.T, dateTimeFormat string, dateTimeFunction *govaluate.EvaluableExpression, expectedFormattedTime string, currentUnixTime int64) {
dslFunctionParameters := map[string]interface{}{"dateTimeFormat": dateTimeFormat}
if currentUnixTime != 0 {
dslFunctionParameters["unixTime"] = currentUnixTime
}
result, err := dateTimeFunction.Evaluate(dslFunctionParameters)
data := time.Now()
result, err := compiled.Evaluate(map[string]interface{}{"data": data})
require.Nil(t, err, "could not evaluate compare time") require.Nil(t, err, "could not evaluate compare time")
require.Equal(t, data.String(), result.(string), "could not get correct time format string") require.Equal(t, expectedFormattedTime, result.(string), "could not get correct time format string")
}
t.Run("with Unix time", func(t *testing.T) {
dateTimeFunction, err := govaluate.NewEvaluableExpressionWithFunctions("date_time(dateTimeFormat)", HelperFunctions())
require.Nil(t, err, "could not compile encoder")
currentTime := time.Now()
expectedFormattedTime := currentTime.Format("02-01-2006 15:04")
testDateTimeFormat(t, "02-01-2006 15:04", dateTimeFunction, expectedFormattedTime, 0)
testDateTimeFormat(t, "%D-%M-%Y %H:%m", dateTimeFunction, expectedFormattedTime, 0)
})
t.Run("without Unix time", func(t *testing.T) {
dateTimeFunction, err := govaluate.NewEvaluableExpressionWithFunctions("date_time(dateTimeFormat, unixTime)", HelperFunctions())
require.Nil(t, err, "could not compile encoder")
currentTime := time.Now()
currentUnixTime := currentTime.Unix()
expectedFormattedTime := currentTime.Format("02-01-2006 15:04")
testDateTimeFormat(t, "02-01-2006 15:04", dateTimeFunction, expectedFormattedTime, currentUnixTime)
testDateTimeFormat(t, "%D-%M-%Y %H:%m", dateTimeFunction, expectedFormattedTime, currentUnixTime)
})
} }
func TestDslFunctionSignatures(t *testing.T) { func TestDslFunctionSignatures(t *testing.T) {
type testCase struct { type testCase struct {
methodName string methodName string
@ -111,7 +139,7 @@ func TestGetPrintableDslFunctionSignatures(t *testing.T) {
compare_versions(firstVersion, constraints ...string) bool compare_versions(firstVersion, constraints ...string) bool
concat(args ...interface{}) string concat(args ...interface{}) string
contains(arg1, arg2 interface{}) interface{} contains(arg1, arg2 interface{}) interface{}
date(arg1 interface{}) interface{} date_time(dateTimeFormat string, optionalUnixTime interface{}) string
dec_to_hex(arg1 interface{}) interface{} dec_to_hex(arg1 interface{}) interface{}
generate_java_gadget(arg1, arg2, arg3 interface{}) interface{} generate_java_gadget(arg1, arg2, arg3 interface{}) interface{}
gzip(arg1 interface{}) interface{} gzip(arg1 interface{}) interface{}
@ -141,9 +169,6 @@ func TestGetPrintableDslFunctionSignatures(t *testing.T) {
reverse(arg1 interface{}) interface{} reverse(arg1 interface{}) interface{}
sha1(arg1 interface{}) interface{} sha1(arg1 interface{}) interface{}
sha256(arg1 interface{}) interface{} sha256(arg1 interface{}) interface{}
time(arg1 interface{}) interface{}
time_format(arg1 interface{}) interface{}
time_to_string(arg1 interface{}) interface{}
to_lower(arg1 interface{}) interface{} to_lower(arg1 interface{}) interface{}
to_number(arg1 interface{}) interface{} to_number(arg1 interface{}) interface{}
to_string(arg1 interface{}) interface{} to_string(arg1 interface{}) interface{}
@ -183,8 +208,9 @@ func TestDslExpressions(t *testing.T) {
`hex_encode("aa")`: "6161", `hex_encode("aa")`: "6161",
`html_escape("<body>test</body>")`: "&lt;body&gt;test&lt;/body&gt;", `html_escape("<body>test</body>")`: "&lt;body&gt;test&lt;/body&gt;",
`html_unescape("&lt;body&gt;test&lt;/body&gt;")`: "<body>test</body>", `html_unescape("&lt;body&gt;test&lt;/body&gt;")`: "<body>test</body>",
`date("%Y-%M-%D")`: fmt.Sprintf("%02d-%02d-%02d", now.Year(), now.Month(), now.Day()), `date_time("%Y-%M-%D")`: fmt.Sprintf("%02d-%02d-%02d", now.Year(), now.Month(), now.Day()),
`time("%H-%M")`: fmt.Sprintf("%02d-%02d", now.Hour(), now.Minute()), `date_time("%H-%m")`: fmt.Sprintf("%02d-%02d", now.Hour(), now.Minute()),
`date_time("02-01-2006 15:04")`: now.Format("02-01-2006 15:04"),
`md5("Hello")`: "8b1a9953c4611296a827abf8c47804d7", `md5("Hello")`: "8b1a9953c4611296a827abf8c47804d7",
`md5(1234)`: "81dc9bdb52d04dc20036dbd8313ed055", `md5(1234)`: "81dc9bdb52d04dc20036dbd8313ed055",
`mmh3("Hello")`: "316307400", `mmh3("Hello")`: "316307400",
@ -232,7 +258,6 @@ func TestDslExpressions(t *testing.T) {
`compare_versions('v1.0.0', '>v0.0.1', '<v1.0.1')`: true, `compare_versions('v1.0.0', '>v0.0.1', '<v1.0.1')`: true,
`hmac('sha1', 'test', 'scrt')`: "8856b111056d946d5c6c92a21b43c233596623c6", `hmac('sha1', 'test', 'scrt')`: "8856b111056d946d5c6c92a21b43c233596623c6",
`hmac('sha256', 'test', 'scrt')`: "1f1bff5574f18426eb376d6dd5368a754e67a798aa2074644d5e3fd4c90c7a92", `hmac('sha256', 'test', 'scrt')`: "1f1bff5574f18426eb376d6dd5368a754e67a798aa2074644d5e3fd4c90c7a92",
`time_format("02-01-2006 15:04")`: now.Format("02-01-2006 15:04"),
} }
for dslExpression, expectedResult := range dslExpressions { for dslExpression, expectedResult := range dslExpressions {