diff --git a/sql/src/main/antlr/OpenSearchSQLParser.g4 b/sql/src/main/antlr/OpenSearchSQLParser.g4 index 5f7361160b3..68887b0f670 100644 --- a/sql/src/main/antlr/OpenSearchSQLParser.g4 +++ b/sql/src/main/antlr/OpenSearchSQLParser.g4 @@ -111,6 +111,15 @@ fromClause relation : tableName (AS? alias)? # tableAsRelation | LR_BRACKET subquery = querySpecification RR_BRACKET AS? alias # subqueryAsRelation + | qualifiedName LR_BRACKET tableFunctionArgs RR_BRACKET AS? alias # tableFunctionRelation + ; + +tableFunctionArgs + : tableFunctionArg (COMMA tableFunctionArg)* + ; + +tableFunctionArg + : ident EQUAL_SYMBOL functionArg ; whereClause diff --git a/sql/src/main/java/org/opensearch/sql/sql/parser/AstBuilder.java b/sql/src/main/java/org/opensearch/sql/sql/parser/AstBuilder.java index bdbc360713c..bacdf9d7378 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/parser/AstBuilder.java +++ b/sql/src/main/java/org/opensearch/sql/sql/parser/AstBuilder.java @@ -13,6 +13,7 @@ import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.SelectElementContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.SubqueryAsRelationContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.TableAsRelationContext; +import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.TableFunctionRelationContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.WhereClauseContext; import static org.opensearch.sql.sql.parser.ParserUtils.getTextInQuery; import static org.opensearch.sql.utils.SystemIndexUtils.TABLE_INFO; @@ -20,12 +21,14 @@ import com.google.common.collect.ImmutableList; import java.util.Collections; +import java.util.Locale; import java.util.Optional; import lombok.RequiredArgsConstructor; import org.antlr.v4.runtime.tree.ParseTree; import org.opensearch.sql.ast.expression.Alias; import org.opensearch.sql.ast.expression.AllFields; import org.opensearch.sql.ast.expression.Function; +import org.opensearch.sql.ast.expression.UnresolvedArgument; import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.ast.tree.DescribeRelation; import org.opensearch.sql.ast.tree.Filter; @@ -34,6 +37,7 @@ import org.opensearch.sql.ast.tree.Relation; import org.opensearch.sql.ast.tree.RelationSubquery; import org.opensearch.sql.ast.tree.SubqueryAlias; +import org.opensearch.sql.ast.tree.TableFunction; import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.ast.tree.Values; import org.opensearch.sql.common.antlr.SyntaxCheckException; @@ -189,6 +193,24 @@ public UnresolvedPlan visitSubqueryAsRelation(SubqueryAsRelationContext ctx) { return new RelationSubquery(visit(ctx.subquery), subqueryAlias); } + @Override + public UnresolvedPlan visitTableFunctionRelation(TableFunctionRelationContext ctx) { + ImmutableList.Builder args = ImmutableList.builder(); + ctx.tableFunctionArgs() + .tableFunctionArg() + .forEach( + arg -> { + String argName = + StringUtils.unquoteIdentifier(arg.ident().getText()).toLowerCase(Locale.ROOT); + UnresolvedExpression argValue = visitAstExpression(arg.functionArg()); + args.add(new UnresolvedArgument(argName, argValue)); + }); + TableFunction tableFunction = + new TableFunction(visitAstExpression(ctx.qualifiedName()), args.build()); + String alias = StringUtils.unquoteIdentifier(ctx.alias().getText()); + return new SubqueryAlias(alias, tableFunction); + } + @Override public UnresolvedPlan visitWhereClause(WhereClauseContext ctx) { return new Filter(visitAstExpression(ctx.expression())); diff --git a/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java index 1ecaa181e6f..da48b9bb7fc 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java @@ -40,6 +40,9 @@ import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.NestedAllTupleFields; +import org.opensearch.sql.ast.expression.UnresolvedArgument; +import org.opensearch.sql.ast.tree.SubqueryAlias; +import org.opensearch.sql.ast.tree.TableFunction; import org.opensearch.sql.common.antlr.SyntaxCheckException; class AstBuilderTest extends AstBuilderTestBase { @@ -131,6 +134,123 @@ public void can_build_from_index_with_alias_quoted() { buildAST("SELECT `t`.name FROM test `t` WHERE `t`.age = 30")); } + @Test + public void can_build_from_table_function() { + assertEquals( + project( + new SubqueryAlias( + "v", + new TableFunction( + qualifiedName("vectorSearch"), + ImmutableList.of( + new UnresolvedArgument("table", stringLiteral("products")), + new UnresolvedArgument("field", stringLiteral("embedding")), + new UnresolvedArgument("vector", stringLiteral("[0.1,0.2]")), + new UnresolvedArgument("option", stringLiteral("k=10"))))), + AllFields.of()), + buildAST( + "SELECT * FROM vectorSearch(" + + "table='products', field='embedding', " + + "vector='[0.1,0.2]', option='k=10') AS v")); + } + + @Test + public void can_build_from_table_function_with_where_order_limit() { + assertEquals( + project( + limit( + sort( + filter( + new SubqueryAlias( + "s", + new TableFunction( + qualifiedName("vectorSearch"), + ImmutableList.of( + new UnresolvedArgument("table", stringLiteral("products")), + new UnresolvedArgument("field", stringLiteral("embedding")), + new UnresolvedArgument("vector", stringLiteral("[0.1,0.2]")), + new UnresolvedArgument("option", stringLiteral("k=10"))))), + function("=", qualifiedName("s", "category"), stringLiteral("shoes"))), + field(qualifiedName("s", "_score"), argument("asc", booleanLiteral(false)))), + 5, + 0), + alias("s.title", qualifiedName("s", "title")), + alias("s._score", qualifiedName("s", "_score"))), + buildAST( + "SELECT s.title, s._score FROM vectorSearch(" + + "table='products', field='embedding', " + + "vector='[0.1,0.2]', option='k=10') AS s " + + "WHERE s.category = 'shoes' " + + "ORDER BY s._score DESC " + + "LIMIT 5")); + } + + @Test + public void table_function_args_are_resolved_by_name_not_position() { + assertEquals( + project( + new SubqueryAlias( + "v", + new TableFunction( + qualifiedName("vectorSearch"), + ImmutableList.of( + new UnresolvedArgument("option", stringLiteral("k=10")), + new UnresolvedArgument("field", stringLiteral("embedding")), + new UnresolvedArgument("table", stringLiteral("products")), + new UnresolvedArgument("vector", stringLiteral("[0.1,0.2]"))))), + AllFields.of()), + buildAST( + "SELECT * FROM vectorSearch(" + + "option='k=10', field='embedding', " + + "table='products', vector='[0.1,0.2]') AS v")); + } + + @Test + public void table_function_arg_names_are_canonicalized() { + assertEquals( + project( + new SubqueryAlias( + "v", + new TableFunction( + qualifiedName("vectorSearch"), + ImmutableList.of( + new UnresolvedArgument("table", stringLiteral("products")), + new UnresolvedArgument("field", stringLiteral("embedding")), + new UnresolvedArgument("vector", stringLiteral("[0.1,0.2]")), + new UnresolvedArgument("option", stringLiteral("k=10"))))), + AllFields.of()), + buildAST( + "SELECT * FROM vectorSearch(" + + "TABLE='products', FIELD='embedding', " + + "VECTOR='[0.1,0.2]', OPTION='k=10') AS v")); + } + + @Test + public void table_function_allows_alias_without_as_keyword() { + assertEquals( + project( + new SubqueryAlias( + "v", + new TableFunction( + qualifiedName("vectorSearch"), + ImmutableList.of( + new UnresolvedArgument("table", stringLiteral("products")), + new UnresolvedArgument("vector", stringLiteral("[0.1]"))))), + AllFields.of()), + buildAST("SELECT * FROM vectorSearch(table='products', vector='[0.1]') v")); + } + + @Test + public void table_function_relation_requires_alias() { + assertThrows( + SyntaxCheckException.class, + () -> + buildAST( + "SELECT * FROM vectorSearch(" + + "table='products', field='embedding', " + + "vector='[0.1,0.2]', option='k=10')")); + } + @Test public void can_build_where_clause() { assertEquals(