diff --git a/api/models/schema_test.go b/api/models/schema_test.go new file mode 100644 index 0000000..f5c10a0 --- /dev/null +++ b/api/models/schema_test.go @@ -0,0 +1,346 @@ +package models + +import ( + "testing" + + "github.com/awatercolorpen/olap-sql/api/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// ---- Graph ---- + +func TestGraph_GetTree(t *testing.T) { + t.Run("single node", func(t *testing.T) { + g := Graph{"a": {"b", "c"}, "b": nil, "c": nil} + tree := g.GetTree("a") + assert.Equal(t, []string{"b", "c"}, tree["a"]) + }) + + t.Run("multi level", func(t *testing.T) { + g := Graph{"a": {"b"}, "b": {"c"}, "c": nil} + tree := g.GetTree("a") + assert.Equal(t, []string{"b"}, tree["a"]) + assert.Equal(t, []string{"c"}, tree["b"]) + }) + + t.Run("empty start", func(t *testing.T) { + g := Graph{} + tree := g.GetTree("missing") + assert.Empty(t, tree) + }) +} + +// ---- DataSet ---- + +func TestDataSet_GetKey(t *testing.T) { + ds := &DataSet{Name: "my_dataset"} + assert.Equal(t, "my_dataset", ds.GetKey()) +} + +func TestDataSet_GetCurrent(t *testing.T) { + ds := &DataSet{DataSource: "fact_orders"} + assert.Equal(t, "fact_orders", ds.GetCurrent()) +} + +// ---- JoinPair ---- + +func TestJoinPair_GetJoinType(t *testing.T) { + t.Run("get from first", func(t *testing.T) { + pair := JoinPair{ + {DataSource: "a", JoinType: "LEFT JOIN"}, + {DataSource: "b", JoinType: ""}, + } + assert.Equal(t, "LEFT JOIN", pair.GetJoinType()) + }) + + t.Run("get from second", func(t *testing.T) { + pair := JoinPair{ + {DataSource: "a", JoinType: ""}, + {DataSource: "b", JoinType: "INNER JOIN"}, + } + assert.Equal(t, "INNER JOIN", pair.GetJoinType()) + }) + + t.Run("both empty", func(t *testing.T) { + pair := JoinPair{ + {DataSource: "a"}, + {DataSource: "b"}, + } + assert.Equal(t, "", pair.GetJoinType()) + }) +} + +func TestJoinPair_IsValid(t *testing.T) { + t.Run("valid pair", func(t *testing.T) { + pair := JoinPair{ + {DataSource: "a", Dimension: []string{"id"}}, + {DataSource: "b", Dimension: []string{"user_id"}}, + } + assert.NoError(t, pair.IsValid()) + }) + + t.Run("wrong length", func(t *testing.T) { + pair := JoinPair{ + {DataSource: "a"}, + } + assert.Error(t, pair.IsValid()) + }) + + t.Run("mismatched dimensions", func(t *testing.T) { + pair := JoinPair{ + {DataSource: "a", Dimension: []string{"id", "type"}}, + {DataSource: "b", Dimension: []string{"user_id"}}, + } + assert.Error(t, pair.IsValid()) + }) +} + +// ---- DimensionJoins ---- + +func TestDimensionJoins_IsValid(t *testing.T) { + t.Run("empty", func(t *testing.T) { + dj := DimensionJoins{} + assert.Error(t, dj.IsValid()) + }) + + t.Run("valid", func(t *testing.T) { + dj := DimensionJoins{ + { + {DataSource: "fact", Dimension: []string{"user_id"}}, + {DataSource: "dim_user", Dimension: []string{"id"}}, + }, + } + assert.NoError(t, dj.IsValid()) + }) +} + +func TestDimensionJoins_GetDependencyTree(t *testing.T) { + t.Run("simple tree", func(t *testing.T) { + dj := DimensionJoins{ + { + {DataSource: "fact_orders", Dimension: []string{"user_id"}}, + {DataSource: "dim_users", Dimension: []string{"id"}}, + }, + } + g, err := dj.GetDependencyTree("fact_orders") + require.NoError(t, err) + assert.NotEmpty(t, g) + }) +} + +// ---- MergeJoin ---- + +func TestMergeJoin_IsValid(t *testing.T) { + t.Run("too few", func(t *testing.T) { + mj := MergeJoin{ + {DataSource: "fact"}, + {DataSource: "dim"}, + } + assert.Error(t, mj.IsValid("fact")) + }) + + t.Run("wrong first source", func(t *testing.T) { + mj := MergeJoin{ + {DataSource: "other", Dimension: []string{"id"}}, + {DataSource: "dim1", Dimension: []string{"id"}}, + {DataSource: "dim2", Dimension: []string{"id"}}, + } + assert.Error(t, mj.IsValid("fact")) + }) + + t.Run("valid", func(t *testing.T) { + mj := MergeJoin{ + {DataSource: "fact", Dimension: []string{"id"}}, + {DataSource: "dim1", Dimension: []string{"id"}}, + {DataSource: "dim2", Dimension: []string{"id"}}, + } + assert.NoError(t, mj.IsValid("fact")) + }) + + t.Run("dimension mismatch", func(t *testing.T) { + mj := MergeJoin{ + {DataSource: "fact", Dimension: []string{"id", "type"}}, + {DataSource: "dim1", Dimension: []string{"id"}}, + {DataSource: "dim2", Dimension: []string{"id"}}, + } + assert.Error(t, mj.IsValid("fact")) + }) +} + +func TestMergeJoin_GetDependencyTree(t *testing.T) { + mj := MergeJoin{ + {DataSource: "fact"}, + {DataSource: "dim1"}, + {DataSource: "dim2"}, + } + g, err := mj.GetDependencyTree("fact") + require.NoError(t, err) + assert.Contains(t, g["fact"], "dim1") + assert.Contains(t, g["fact"], "dim2") +} + +// ---- DataSource ---- + +func TestDataSource_GetKey(t *testing.T) { + ds := &DataSource{Name: "fact_orders"} + assert.Equal(t, "fact_orders", ds.GetKey()) +} + +func TestDataSource_IsFact(t *testing.T) { + tests := []struct { + dsType types.DataSourceType + expected bool + }{ + {types.DataSourceTypeFact, true}, + {types.DataSourceTypeFactDimensionJoin, true}, + {types.DataSourceTypeMergeJoin, true}, + {types.DataSourceTypeDimension, false}, + } + for _, tt := range tests { + ds := &DataSource{Type: tt.dsType} + assert.Equal(t, tt.expected, ds.IsFact(), "type=%v", tt.dsType) + } +} + +func TestDataSource_IsDimension(t *testing.T) { + assert.True(t, (&DataSource{Type: types.DataSourceTypeDimension}).IsDimension()) + assert.False(t, (&DataSource{Type: types.DataSourceTypeFact}).IsDimension()) +} + +func TestDataSource_IsValid(t *testing.T) { + t.Run("fact valid", func(t *testing.T) { + ds := &DataSource{Type: types.DataSourceTypeFact, Name: "fact"} + assert.NoError(t, ds.IsValid()) + }) + + t.Run("dimension valid", func(t *testing.T) { + ds := &DataSource{Type: types.DataSourceTypeDimension, Name: "dim"} + assert.NoError(t, ds.IsValid()) + }) + + t.Run("fact_dimension_join valid", func(t *testing.T) { + ds := &DataSource{ + Name: "fact", + Type: types.DataSourceTypeFactDimensionJoin, + DimensionJoin: DimensionJoins{ + { + {DataSource: "fact", Dimension: []string{"user_id"}}, + {DataSource: "dim_user", Dimension: []string{"id"}}, + }, + }, + } + assert.NoError(t, ds.IsValid()) + }) + + t.Run("merge_join valid", func(t *testing.T) { + ds := &DataSource{ + Name: "fact", + Type: types.DataSourceTypeMergeJoin, + MergeJoin: MergeJoin{ + {DataSource: "fact", Dimension: []string{"id"}}, + {DataSource: "dim1", Dimension: []string{"id"}}, + {DataSource: "dim2", Dimension: []string{"id"}}, + }, + } + assert.NoError(t, ds.IsValid()) + }) + + t.Run("unknown type", func(t *testing.T) { + ds := &DataSource{Type: "unknown", Name: "x"} + assert.Error(t, ds.IsValid()) + }) +} + +func TestDataSource_GetDependencyTree(t *testing.T) { + t.Run("fact", func(t *testing.T) { + ds := &DataSource{Type: types.DataSourceTypeFact, Name: "fact"} + g, err := ds.GetDependencyTree() + require.NoError(t, err) + assert.Contains(t, g, "fact") + }) + + t.Run("unknown type error", func(t *testing.T) { + ds := &DataSource{Type: "unknown", Name: "x"} + _, err := ds.GetDependencyTree() + assert.Error(t, err) + }) +} + +func TestDataSource_GetGetDependencyKey(t *testing.T) { + t.Run("dimension join keys", func(t *testing.T) { + ds := &DataSource{ + Name: "fact", + Type: types.DataSourceTypeFactDimensionJoin, + DimensionJoin: DimensionJoins{ + { + {DataSource: "fact", Dimension: []string{"uid"}}, + {DataSource: "dim_user", Dimension: []string{"id"}}, + }, + }, + } + keys := ds.GetGetDependencyKey() + assert.Contains(t, keys, "fact") + assert.Contains(t, keys, "dim_user") + }) + + t.Run("merge join keys", func(t *testing.T) { + ds := &DataSource{ + Name: "fact", + Type: types.DataSourceTypeMergeJoin, + MergeJoin: MergeJoin{ + {DataSource: "fact"}, + {DataSource: "dim1"}, + {DataSource: "dim2"}, + }, + } + keys := ds.GetGetDependencyKey() + assert.Contains(t, keys, "dim1") + assert.Contains(t, keys, "dim2") + }) +} + +// ---- DataSources ---- + +func TestDataSources_KeyIndex(t *testing.T) { + dss := DataSources{ + {Name: "fact_orders", Type: types.DataSourceTypeFact}, + {Name: "dim_users", Type: types.DataSourceTypeDimension}, + } + idx := dss.KeyIndex() + assert.Len(t, idx, 2) + assert.Equal(t, "fact_orders", idx["fact_orders"].Name) + assert.Equal(t, "dim_users", idx["dim_users"].Name) +} + +// ---- Dimension ---- + +func TestDimension_GetKey(t *testing.T) { + d := &Dimension{DataSource: "fact_orders", Name: "city"} + assert.Equal(t, "fact_orders.city", d.GetKey()) +} + +func TestDimension_GetDependency(t *testing.T) { + d := &Dimension{Dependency: []string{"a", "b"}} + assert.Equal(t, []string{"a", "b"}, d.GetDependency()) +} + +// ---- Metric ---- + +func TestMetric_GetKey(t *testing.T) { + m := &Metric{DataSource: "fact_orders", Name: "revenue"} + assert.Equal(t, "fact_orders.revenue", m.GetKey()) +} + +func TestMetric_GetDependency(t *testing.T) { + m := &Metric{Dependency: []string{"x", "y"}} + assert.Equal(t, []string{"x", "y"}, m.GetDependency()) +} + +// ---- GetNameFromKey ---- + +func TestGetNameFromKey(t *testing.T) { + assert.Equal(t, "city", GetNameFromKey("fact_orders.city")) + assert.Equal(t, "revenue", GetNameFromKey("fact.revenue")) + assert.Equal(t, "standalone", GetNameFromKey("standalone")) +} diff --git a/api/types/column_test.go b/api/types/column_test.go new file mode 100644 index 0000000..37e36e7 --- /dev/null +++ b/api/types/column_test.go @@ -0,0 +1,247 @@ +package types + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSingleCol_GetExpression_Value(t *testing.T) { + col := &SingleCol{ + DBType: DBTypeSQLite, + Table: "orders", + Name: "status", + Alias: "order_status", + Type: ColumnTypeValue, + } + assert.Equal(t, "`orders`.`status`", col.GetExpression()) + assert.Equal(t, "order_status", col.GetAlias()) +} + +func TestSingleCol_GetExpression_Count(t *testing.T) { + t.Run("count star", func(t *testing.T) { + col := &SingleCol{ + Table: "orders", + Name: "*", + Alias: "cnt", + Type: ColumnTypeCount, + } + assert.Equal(t, "COUNT(*)", col.GetExpression()) + }) + + t.Run("count field", func(t *testing.T) { + col := &SingleCol{ + Table: "orders", + Name: "id", + Alias: "cnt", + Type: ColumnTypeCount, + } + assert.Equal(t, "COUNT( `orders`.`id` )", col.GetExpression()) + }) +} + +func TestSingleCol_GetExpression_DistinctCount(t *testing.T) { + t.Run("no filter", func(t *testing.T) { + col := &SingleCol{ + Table: "orders", + Name: "user_id", + Alias: "uv", + Type: ColumnTypeDistinctCount, + } + assert.Equal(t, "1.0 * COUNT(DISTINCT `orders`.`user_id` )", col.GetExpression()) + }) + + t.Run("with clickhouse filter", func(t *testing.T) { + filter := &Filter{ + OperatorType: FilterOperatorTypeEquals, + ValueType: ValueTypeString, + Name: "type", + Value: []any{"vip"}, + } + col := &SingleCol{ + DBType: DBTypeClickHouse, + Table: "orders", + Name: "user_id", + Alias: "vip_uv", + Type: ColumnTypeDistinctCount, + Filter: filter, + } + expr := col.GetExpression() + assert.Contains(t, expr, "COUNT(DISTINCT") + assert.Contains(t, expr, "IF(") + }) +} + +func TestSingleCol_GetExpression_Sum(t *testing.T) { + t.Run("no filter", func(t *testing.T) { + col := &SingleCol{ + Table: "orders", + Name: "amount", + Alias: "total", + Type: ColumnTypeSum, + } + assert.Equal(t, "1.0 * SUM(`orders`.`amount`)", col.GetExpression()) + }) + + t.Run("with sqlite filter", func(t *testing.T) { + filter := &Filter{ + OperatorType: FilterOperatorTypeEquals, + ValueType: ValueTypeString, + Name: "status", + Value: []any{"paid"}, + } + col := &SingleCol{ + DBType: DBTypeSQLite, + Table: "orders", + Name: "amount", + Alias: "paid_amount", + Type: ColumnTypeSum, + Filter: filter, + } + expr := col.GetExpression() + assert.Contains(t, expr, "SUM(") + assert.Contains(t, expr, "IIF(") + }) +} + +func TestSingleCol_GetExpression_Unsupported(t *testing.T) { + col := &SingleCol{ + Table: "t", + Name: "x", + Type: "unknown_type", + } + expr := col.GetExpression() + assert.Contains(t, expr, "unsupported type") +} + +func TestSingleCol_GetExpression_EmptyName(t *testing.T) { + // When Name is empty, getSimpleName uses Alias + col := &SingleCol{ + Table: "orders", + Name: "", + Alias: "my_alias", + Type: ColumnTypeValue, + } + assert.Equal(t, "`orders`.`my_alias`", col.GetExpression()) +} + +func TestSingleCol_GetIfExpression(t *testing.T) { + filter := &Filter{ + OperatorType: FilterOperatorTypeEquals, + ValueType: ValueTypeString, + Name: "status", + Value: []any{"active"}, + } + + t.Run("clickhouse", func(t *testing.T) { + col := &SingleCol{ + DBType: DBTypeClickHouse, + Table: "users", + Name: "id", + Filter: filter, + } + expr, err := col.GetIfExpression() + require.NoError(t, err) + assert.Contains(t, expr, "IF(") + assert.Contains(t, expr, "`users`.`id`") + }) + + t.Run("sqlite", func(t *testing.T) { + col := &SingleCol{ + DBType: DBTypeSQLite, + Table: "users", + Name: "id", + Filter: filter, + } + expr, err := col.GetIfExpression() + require.NoError(t, err) + assert.Contains(t, expr, "IIF(") + }) + + t.Run("unsupported dbtype", func(t *testing.T) { + col := &SingleCol{ + DBType: "unknown_db", + Table: "users", + Name: "id", + Filter: filter, + } + _, err := col.GetIfExpression() + assert.Error(t, err) + }) +} + +func TestArithmeticCol_GetExpression(t *testing.T) { + left := &SingleCol{Table: "t", Name: "a", Alias: "a", Type: ColumnTypeValue} + right := &SingleCol{Table: "t", Name: "b", Alias: "b", Type: ColumnTypeValue} + + t.Run("add", func(t *testing.T) { + col := &ArithmeticCol{ + Column: []Column{left, right}, + Alias: "sum_ab", + Type: ColumnTypeAdd, + } + expr := col.GetExpression() + assert.Contains(t, expr, "+") + assert.Contains(t, expr, "IFNULL") + }) + + t.Run("subtract", func(t *testing.T) { + col := &ArithmeticCol{ + Column: []Column{left, right}, + Alias: "diff_ab", + Type: ColumnTypeSubtract, + } + expr := col.GetExpression() + assert.Contains(t, expr, "-") + }) + + t.Run("multiply", func(t *testing.T) { + col := &ArithmeticCol{ + Column: []Column{left, right}, + Alias: "mul_ab", + Type: ColumnTypeMultiply, + } + expr := col.GetExpression() + assert.Contains(t, expr, "*") + }) + + t.Run("divide", func(t *testing.T) { + col := &ArithmeticCol{ + Column: []Column{left, right}, + Alias: "div_ab", + Type: ColumnTypeDivide, + } + expr := col.GetExpression() + assert.Contains(t, expr, "/") + assert.Contains(t, expr, "NULLIF") + }) + + t.Run("as", func(t *testing.T) { + col := &ArithmeticCol{ + Column: []Column{left}, + Alias: "as_a", + Type: ColumnTypeAs, + } + expr := col.GetExpression() + assert.Contains(t, expr, "`t`.`a`") + }) + + t.Run("alias", func(t *testing.T) { + col := &ArithmeticCol{ + Column: []Column{left, right}, + Alias: "result", + Type: ColumnTypeAdd, + } + assert.Equal(t, "result", col.GetAlias()) + }) +} + +func TestExpressionCol(t *testing.T) { + col := &ExpressionCol{ + Expression: "CUSTOM_FUNC(x, y)", + Alias: "custom", + } + assert.Equal(t, "CUSTOM_FUNC(x, y)", col.GetExpression()) + assert.Equal(t, "custom", col.GetAlias()) +} diff --git a/api/types/filter_test.go b/api/types/filter_test.go new file mode 100644 index 0000000..a51159d --- /dev/null +++ b/api/types/filter_test.go @@ -0,0 +1,253 @@ +package types + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFilterOperatorType_IsTree(t *testing.T) { + tests := []struct { + op FilterOperatorType + expected bool + }{ + {FilterOperatorTypeAnd, true}, + {FilterOperatorTypeOr, true}, + {FilterOperatorTypeEquals, false}, + {FilterOperatorTypeIn, false}, + {FilterOperatorTypeUnknown, false}, + } + for _, tt := range tests { + assert.Equal(t, tt.expected, tt.op.IsTree(), "op=%v", tt.op) + } +} + +func TestFilter_Expression_SimpleOperators(t *testing.T) { + tests := []struct { + name string + filter *Filter + expected string + }{ + { + name: "equals string", + filter: &Filter{ + OperatorType: FilterOperatorTypeEquals, + ValueType: ValueTypeString, + Name: "city", + Value: []any{"beijing"}, + }, + expected: "city = 'beijing'", + }, + { + name: "equals integer", + filter: &Filter{ + OperatorType: FilterOperatorTypeEquals, + ValueType: ValueTypeInteger, + Name: "age", + Value: []any{18}, + }, + expected: "age = 18", + }, + { + name: "in string list", + filter: &Filter{ + OperatorType: FilterOperatorTypeIn, + ValueType: ValueTypeString, + Name: "status", + Value: []any{"active", "pending"}, + }, + expected: "status IN ('active', 'pending')", + }, + { + name: "not in", + filter: &Filter{ + OperatorType: FilterOperatorTypeNotIn, + ValueType: ValueTypeString, + Name: "type", + Value: []any{"spam"}, + }, + expected: "type NOT IN ('spam')", + }, + { + name: "less equals", + filter: &Filter{ + OperatorType: FilterOperatorTypeLessEquals, + ValueType: ValueTypeFloat, + Name: "score", + Value: []any{100.0}, + }, + expected: "score <= 100", + }, + { + name: "less", + filter: &Filter{ + OperatorType: FilterOperatorTypeLess, + ValueType: ValueTypeInteger, + Name: "count", + Value: []any{50}, + }, + expected: "count < 50", + }, + { + name: "greater equals", + filter: &Filter{ + OperatorType: FilterOperatorTypeGreaterEquals, + ValueType: ValueTypeInteger, + Name: "rank", + Value: []any{1}, + }, + expected: "rank >= 1", + }, + { + name: "greater", + filter: &Filter{ + OperatorType: FilterOperatorTypeGreater, + ValueType: ValueTypeInteger, + Name: "id", + Value: []any{0}, + }, + expected: "id > 0", + }, + { + name: "like", + filter: &Filter{ + OperatorType: FilterOperatorTypeLike, + ValueType: ValueTypeString, + Name: "name", + Value: []any{"%test%"}, + }, + expected: "name LIKE '%test%'", + }, + { + name: "has", + filter: &Filter{ + OperatorType: FilterOperatorTypeHas, + ValueType: ValueTypeString, + Name: "tags", + Value: []any{"go"}, + }, + expected: "has(tags, 'go')", + }, + { + // FilterOperatorTypeExpression passes value through valueToStringSlice first, + // so string values get quoted. Use ValueTypeUnknown with a numeric-like string + // if you need a raw expression, or accept the quoting behavior. + name: "expression with integer value", + filter: &Filter{ + OperatorType: FilterOperatorTypeExpression, + ValueType: ValueTypeInteger, + Name: "x", + Value: []any{42}, + }, + expected: "42", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.filter.Expression() + require.NoError(t, err) + assert.Equal(t, tt.expected, got) + }) + } +} + +func TestFilter_Expression_TreeOperators(t *testing.T) { + t.Run("and", func(t *testing.T) { + f := &Filter{ + OperatorType: FilterOperatorTypeAnd, + Children: []*Filter{ + {OperatorType: FilterOperatorTypeEquals, ValueType: ValueTypeString, Name: "city", Value: []any{"bj"}}, + {OperatorType: FilterOperatorTypeGreater, ValueType: ValueTypeInteger, Name: "age", Value: []any{18}}, + }, + } + got, err := f.Expression() + require.NoError(t, err) + assert.Equal(t, "( city = 'bj' AND age > 18 )", got) + }) + + t.Run("or", func(t *testing.T) { + f := &Filter{ + OperatorType: FilterOperatorTypeOr, + Children: []*Filter{ + {OperatorType: FilterOperatorTypeEquals, ValueType: ValueTypeString, Name: "type", Value: []any{"A"}}, + {OperatorType: FilterOperatorTypeEquals, ValueType: ValueTypeString, Name: "type", Value: []any{"B"}}, + }, + } + got, err := f.Expression() + require.NoError(t, err) + assert.Equal(t, "( type = 'A' OR type = 'B' )", got) + }) +} + +func TestFilter_Expression_Errors(t *testing.T) { + t.Run("unknown operator", func(t *testing.T) { + f := &Filter{ + OperatorType: FilterOperatorTypeUnknown, + ValueType: ValueTypeString, + Name: "x", + Value: []any{"v"}, + } + _, err := f.Expression() + assert.Error(t, err) + }) + + t.Run("unsupported value type", func(t *testing.T) { + f := &Filter{ + OperatorType: FilterOperatorTypeEquals, + ValueType: "UNSUPPORTED_TYPE", + Name: "x", + Value: []any{"v"}, + } + _, err := f.Expression() + assert.Error(t, err) + }) +} + +func TestFilter_Statement(t *testing.T) { + f := &Filter{ + OperatorType: FilterOperatorTypeEquals, + ValueType: ValueTypeString, + Name: "env", + Value: []any{"prod"}, + } + got, err := f.Statement() + require.NoError(t, err) + assert.Equal(t, "env = 'prod'", got) +} + +func TestFilter_Alias(t *testing.T) { + f := &Filter{} + _, err := f.Alias() + assert.Error(t, err) +} + +func TestTryToParseValue(t *testing.T) { + assert.Equal(t, "'hello'", tryToParseValue("hello")) + assert.Equal(t, "42", tryToParseValue(42)) + assert.Equal(t, "3.14", tryToParseValue(3.14)) + assert.Equal(t, "true", tryToParseValue(true)) // default branch +} + +func TestFilter_ValueType_Unknown(t *testing.T) { + // Unknown value type should auto-detect + f := &Filter{ + OperatorType: FilterOperatorTypeEquals, + ValueType: ValueTypeUnknown, + Name: "x", + Value: []any{"hello"}, + } + got, err := f.Expression() + require.NoError(t, err) + assert.Equal(t, "x = 'hello'", got) + + f2 := &Filter{ + OperatorType: FilterOperatorTypeEquals, + ValueType: "", + Name: "n", + Value: []any{99}, + } + got2, err := f2.Expression() + require.NoError(t, err) + assert.Equal(t, "n = 99", got2) +} diff --git a/api/types/order_test.go b/api/types/order_test.go new file mode 100644 index 0000000..83834d0 --- /dev/null +++ b/api/types/order_test.go @@ -0,0 +1,84 @@ +package types + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestOrderBy_Expression(t *testing.T) { + tests := []struct { + name string + order *OrderBy + expected string + expectErr bool + }{ + { + name: "dimension ascending", + order: &OrderBy{ + Table: "users", + Name: "created_at", + FieldProperty: FieldPropertyDimension, + Direction: OrderDirectionTypeAscending, + }, + expected: "created_at ASC", + }, + { + name: "metric descending", + order: &OrderBy{ + Table: "orders", + Name: "total_amount", + FieldProperty: FieldPropertyMetric, + Direction: OrderDirectionTypeDescending, + }, + expected: "total_amount DESC", + }, + { + name: "unknown direction", + order: &OrderBy{ + Name: "x", + FieldProperty: FieldPropertyDimension, + Direction: OrderDirectionTypeUnknown, + }, + expectErr: true, + }, + { + name: "unsupported field property", + order: &OrderBy{ + Name: "x", + FieldProperty: "UNKNOWN", + Direction: OrderDirectionTypeAscending, + }, + expectErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.order.Expression() + if tt.expectErr { + assert.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expected, got) + } + }) + } +} + +func TestOrderBy_Statement(t *testing.T) { + o := &OrderBy{ + Name: "score", + FieldProperty: FieldPropertyMetric, + Direction: OrderDirectionTypeDescending, + } + got, err := o.Statement() + require.NoError(t, err) + assert.Equal(t, "score DESC", got) +} + +func TestOrderBy_Alias(t *testing.T) { + o := &OrderBy{} + _, err := o.Alias() + assert.Error(t, err) +} diff --git a/api/types/types_test.go b/api/types/types_test.go new file mode 100644 index 0000000..4774500 --- /dev/null +++ b/api/types/types_test.go @@ -0,0 +1,339 @@ +package types + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// ---- Dimension ---- + +func TestDimension_Expression(t *testing.T) { + tests := []struct { + name string + dim *Dimension + expected string + expectErr bool + }{ + { + name: "value type", + dim: &Dimension{Table: "orders", Name: "city", Type: DimensionTypeValue}, + expected: "orders.city", + }, + { + name: "single type", + dim: &Dimension{Table: "users", Name: "country", FieldName: "country_code", Type: DimensionTypeSingle}, + expected: "users.country_code", + }, + { + name: "expression type", + dim: &Dimension{Name: "custom", FieldName: "DATE(created_at)", Type: DimensionTypeExpression}, + expected: "DATE(created_at)", + }, + { + name: "multi type", + dim: &Dimension{ + Name: "multi", + Type: DimensionTypeMulti, + Dependency: []*Dimension{ + {Table: "t1", Name: "col", Type: DimensionTypeValue}, + }, + }, + expected: "t1.col", + }, + { + name: "multi type empty dependency", + dim: &Dimension{Name: "multi", Type: DimensionTypeMulti, Dependency: nil}, + expectErr: true, + }, + { + name: "case type", + dim: &Dimension{ + Name: "case_dim", + Type: DimensionTypeCase, + Dependency: []*Dimension{ + {Table: "t1", Name: "a", Type: DimensionTypeValue}, + {Table: "t2", Name: "b", Type: DimensionTypeValue}, + }, + }, + expected: "CASE WHEN t1.a != '' THEN t1.a WHEN t2.b != '' THEN t2.b END", + }, + { + name: "case type too few dependencies", + dim: &Dimension{Name: "c", Type: DimensionTypeCase, Dependency: []*Dimension{{Table: "t", Name: "x", Type: DimensionTypeValue}}}, + expectErr: true, + }, + { + name: "unsupported type", + dim: &Dimension{Name: "x", Type: "UNSUPPORTED"}, + expectErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.dim.Expression() + if tt.expectErr { + assert.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, tt.expected, got) + } + }) + } +} + +func TestDimension_Alias(t *testing.T) { + d := &Dimension{Name: "city"} + alias, err := d.Alias() + require.NoError(t, err) + assert.Equal(t, "city", alias) +} + +func TestDimension_Statement(t *testing.T) { + d := &Dimension{Table: "orders", Name: "city", Type: DimensionTypeValue} + stmt, err := d.Statement() + require.NoError(t, err) + assert.Equal(t, "orders.city AS city", stmt) +} + +// ---- Metric ---- + +func TestMetric_Expression(t *testing.T) { + tests := []struct { + name string + metric *Metric + contains string + expectErr bool + }{ + { + name: "value", + metric: &Metric{Table: "orders", Name: "status", FieldName: "status", Type: MetricTypeValue}, + contains: "`orders`.`status`", + }, + { + name: "count", + metric: &Metric{Table: "orders", Name: "cnt", FieldName: "id", Type: MetricTypeCount}, + contains: "COUNT( `orders`.`id` )", + }, + { + name: "count star", + metric: &Metric{Table: "orders", Name: "cnt", FieldName: "*", Type: MetricTypeCount}, + contains: "COUNT(*)", + }, + { + name: "distinct_count", + metric: &Metric{Table: "orders", Name: "uv", FieldName: "user_id", Type: MetricTypeDistinctCount}, + contains: "COUNT(DISTINCT", + }, + { + name: "sum", + metric: &Metric{Table: "orders", Name: "total", FieldName: "amount", Type: MetricTypeSum}, + contains: "SUM(", + }, + { + name: "expression", + metric: &Metric{Name: "custom", FieldName: "CUSTOM_FUNC(x)", Type: MetricTypeExpression}, + contains: "CUSTOM_FUNC(x)", + }, + { + name: "add", + metric: &Metric{ + Name: "ab", + Type: MetricTypeAdd, + Children: []*Metric{ + {Table: "t", Name: "a", FieldName: "a", Type: MetricTypeValue}, + {Table: "t", Name: "b", FieldName: "b", Type: MetricTypeValue}, + }, + }, + contains: "+", + }, + { + name: "subtract", + metric: &Metric{ + Name: "diff", + Type: MetricTypeSubtract, + Children: []*Metric{ + {Table: "t", Name: "a", FieldName: "a", Type: MetricTypeValue}, + {Table: "t", Name: "b", FieldName: "b", Type: MetricTypeValue}, + }, + }, + contains: "-", + }, + { + name: "multiply", + metric: &Metric{ + Name: "mul", + Type: MetricTypeMultiply, + Children: []*Metric{ + {Table: "t", Name: "a", FieldName: "a", Type: MetricTypeValue}, + {Table: "t", Name: "b", FieldName: "b", Type: MetricTypeValue}, + }, + }, + contains: "*", + }, + { + name: "divide", + metric: &Metric{ + Name: "div", + Type: MetricTypeDivide, + Children: []*Metric{ + {Table: "t", Name: "a", FieldName: "a", Type: MetricTypeValue}, + {Table: "t", Name: "b", FieldName: "b", Type: MetricTypeValue}, + }, + }, + contains: "NULLIF", + }, + { + name: "as", + metric: &Metric{ + Name: "as_m", + Type: MetricTypeAs, + Children: []*Metric{ + {Table: "t", Name: "a", FieldName: "a", Type: MetricTypeValue}, + }, + }, + contains: "`t`.`a`", + }, + { + name: "unknown type", + metric: &Metric{Name: "x", Type: MetricTypeUnknown}, + expectErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.metric.Expression() + if tt.expectErr { + assert.Error(t, err) + } else { + require.NoError(t, err) + assert.Contains(t, got, tt.contains) + } + }) + } +} + +func TestMetric_Alias(t *testing.T) { + m := &Metric{Table: "orders", Name: "revenue", FieldName: "amount", Type: MetricTypeSum} + alias, err := m.Alias() + require.NoError(t, err) + assert.Equal(t, "revenue", alias) +} + +func TestMetric_Statement(t *testing.T) { + m := &Metric{Table: "orders", Name: "cnt", FieldName: "*", Type: MetricTypeCount} + stmt, err := m.Statement() + require.NoError(t, err) + assert.Contains(t, stmt, "AS cnt") +} + +// ---- Join ---- + +func TestJoin_GetJoinType(t *testing.T) { + t.Run("explicit join type", func(t *testing.T) { + j := &Join{JoinType: "INNER JOIN"} + assert.Equal(t, "INNER JOIN", j.GetJoinType()) + }) + + t.Run("default left join", func(t *testing.T) { + j := &Join{} + assert.Equal(t, "LEFT JOIN", j.GetJoinType()) + }) +} + +// ---- DataSource (types package) ---- + +func TestTypesDataSource_Expression(t *testing.T) { + t.Run("with database", func(t *testing.T) { + ds := &DataSource{Database: "mydb", Name: "orders"} + expr, err := ds.Expression() + require.NoError(t, err) + assert.Equal(t, "`mydb`.`orders`", expr) + }) + + t.Run("without database", func(t *testing.T) { + ds := &DataSource{Name: "orders"} + expr, err := ds.Expression() + require.NoError(t, err) + assert.Equal(t, "`orders`", expr) + }) +} + +func TestTypesDataSource_Alias(t *testing.T) { + t.Run("with alias", func(t *testing.T) { + ds := &DataSource{Name: "orders", AliasName: "o"} + alias, err := ds.Alias() + require.NoError(t, err) + assert.Equal(t, "o", alias) + }) + + t.Run("without alias defaults to name", func(t *testing.T) { + ds := &DataSource{Name: "orders"} + alias, err := ds.Alias() + require.NoError(t, err) + assert.Equal(t, "orders", alias) + }) +} + +func TestTypesDataSource_Statement(t *testing.T) { + ds := &DataSource{Database: "db", Name: "orders"} + stmt, err := ds.Statement() + require.NoError(t, err) + assert.Equal(t, "`db`.`orders` AS orders", stmt) +} + +// ---- Query / TimeInterval ---- + +func TestTimeInterval_ToFilter(t *testing.T) { + ti := &TimeInterval{Name: "created_at", Start: "2026-01-01", End: "2026-02-01"} + f := ti.ToFilter() + assert.Equal(t, FilterOperatorTypeAnd, f.OperatorType) + assert.Len(t, f.Children, 2) + assert.Equal(t, FilterOperatorTypeGreaterEquals, f.Children[0].OperatorType) + assert.Equal(t, FilterOperatorTypeLess, f.Children[1].OperatorType) +} + +func TestQuery_TranslateTimeIntervalToFilter(t *testing.T) { + t.Run("with valid time interval", func(t *testing.T) { + q := &Query{ + TimeInterval: &TimeInterval{Name: "ts", Start: "2026-01-01", End: "2026-02-01"}, + } + q.TranslateTimeIntervalToFilter() + assert.Len(t, q.Filters, 1) + }) + + t.Run("with nil time interval", func(t *testing.T) { + q := &Query{} + q.TranslateTimeIntervalToFilter() + assert.Empty(t, q.Filters) + }) + + t.Run("with empty start/end", func(t *testing.T) { + q := &Query{ + TimeInterval: &TimeInterval{Name: "ts"}, + } + q.TranslateTimeIntervalToFilter() + assert.Empty(t, q.Filters) + }) +} + +// ---- Result ---- + +func TestResult_SetDimensions(t *testing.T) { + r := &Result{} + q := &Query{ + Dimensions: []string{"city", "country"}, + Metrics: []string{"revenue", "cnt"}, + } + r.SetDimensions(q) + assert.Equal(t, []string{"city", "country", "revenue", "cnt"}, r.Dimensions) +} + +func TestResult_AddSource(t *testing.T) { + r := &Result{} + err := r.AddSource(map[string]any{"city": "beijing", "revenue": 100}) + require.NoError(t, err) + assert.Len(t, r.Source, 1) + assert.Equal(t, "beijing", r.Source[0]["city"]) +}