/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.opensearch.sql.common.cluster;

import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.regex.Pattern;
import org.apache.commons.text.similarity.CosineSimilarity;

/**
 * Greedy single-pass text similarity clustering for grouping similar text values. Events are
 * processed in order; each is compared to existing cluster representatives using cosine similarity.
 * If the best match meets the threshold, the event joins that cluster; otherwise a new cluster is
 * created.
 *
 * <p>Optimized for incremental processing with vector caching and memory-efficient operations.
 *
 * <p>NOTE(review): not thread-safe — the vector cache is an unsynchronized access-ordered
 * {@link LinkedHashMap}; confirm each instance is confined to a single thread.
 */
public class TextSimilarityClustering {

  private static final CosineSimilarity COSINE = new CosineSimilarity();

  /** Maximum number of vectorized values retained in the LRU vector cache. */
  private static final int MAX_CACHE_SIZE = 10000;

  /** Matches tokens that are entirely digits; such tokens are collapsed to a wildcard. */
  private static final Pattern NUMERIC_PATTERN = Pattern.compile("^\\d+$");

  // LRU cache (access-ordered) of vectorized representations to avoid recomputation.
  private final Map<String, Map<CharSequence, Integer>> vectorCache =
      new LinkedHashMap<>(MAX_CACHE_SIZE, 0.75f, true) {
        @Override
        protected boolean removeEldestEntry(
            Map.Entry<String, Map<CharSequence, Integer>> eldest) {
          return size() > MAX_CACHE_SIZE;
        }
      };

  // Retained for validation; the clustering decision itself is made by callers.
  private final double threshold;
  private final String matchMode;
  private final String delims;

  /**
   * Creates a clustering helper.
   *
   * @param threshold similarity threshold; must be strictly between 0.0 and 1.0
   * @param matchMode one of {@code termlist}, {@code termset}, {@code ngramset}; {@code null}
   *     defaults to {@code termlist}
   * @param delims delimiter characters for tokenization; {@code null} or empty defaults to a
   *     single space (an empty string would otherwise build the invalid regex {@code []+})
   * @throws IllegalArgumentException if the threshold or match mode is invalid
   */
  public TextSimilarityClustering(double threshold, String matchMode, String delims) {
    this.threshold = validateThreshold(threshold);
    this.matchMode = validateMatchMode(matchMode);
    this.delims = (delims == null || delims.isEmpty()) ? " " : delims;
  }

  private static double validateThreshold(double threshold) {
    if (threshold <= 0.0 || threshold >= 1.0) {
      throw new IllegalArgumentException(
          "The threshold must be > 0.0 and < 1.0, got: " + threshold);
    }
    return threshold;
  }

  private static String validateMatchMode(String matchMode) {
    if (matchMode == null) {
      return "termlist";
    }
    switch (matchMode.toLowerCase()) {
      case "termlist":
      case "termset":
      case "ngramset":
        return matchMode.toLowerCase();
      default:
        throw new IllegalArgumentException(
            "Invalid match mode: " + matchMode + ". Must be one of: termlist, termset, ngramset");
    }
  }

  /**
   * Compute similarity between two text values using the configured match mode. Used for
   * incremental clustering against cluster representatives.
   *
   * <p>Null inputs are treated as empty strings; two empty values are a perfect match (1.0), and
   * an empty value never matches a non-empty one (0.0).
   */
  public double computeSimilarity(String text1, String text2) {
    String normalizedText1 = (text1 == null) ? "" : text1;
    String normalizedText2 = (text2 == null) ? "" : text2;

    if (normalizedText1.isEmpty() && normalizedText2.isEmpty()) {
      return 1.0;
    }
    if (normalizedText1.isEmpty() || normalizedText2.isEmpty()) {
      return 0.0;
    }

    // Both non-empty — compute cosine similarity over cached term-frequency vectors.
    Map<CharSequence, Integer> vector1 = vectorizeWithCache(normalizedText1);
    Map<CharSequence, Integer> vector2 = vectorizeWithCache(normalizedText2);
    return COSINE.cosineSimilarity(vector1, vector2);
  }

  private Map<CharSequence, Integer> vectorizeWithCache(String value) {
    return vectorCache.computeIfAbsent(value, this::vectorize);
  }

  private Map<CharSequence, Integer> vectorize(String value) {
    if (value == null || value.isEmpty()) {
      return Map.of();
    }
    return switch (matchMode) {
      case "termset" -> vectorizeTermSet(value);
      case "ngramset" -> vectorizeNgramSet(value);
      default -> vectorizeTermList(value);
    };
  }

  private static String normalizeToken(String token) {
    return NUMERIC_PATTERN.matcher(token).matches() ? "*" : token;
  }

  /** Positional term frequency — token order matters. */
  private Map<CharSequence, Integer> vectorizeTermList(String value) {
    String[] tokens = tokenize(value);
    Map<CharSequence, Integer> vector = new HashMap<>((int) (tokens.length * 1.4));

    for (int i = 0; i < tokens.length; i++) {
      if (!tokens[i].isEmpty()) {
        // Prefix with position so "a b" and "b a" vectorize differently.
        String key = i + "-" + normalizeToken(tokens[i]);
        vector.merge(key, 1, Integer::sum);
      }
    }
    return vector;
  }

  /** Bag-of-words term frequency — token order ignored. */
  private Map<CharSequence, Integer> vectorizeTermSet(String value) {
    String[] tokens = tokenize(value);
    Map<CharSequence, Integer> vector = new HashMap<>((int) (tokens.length * 1.4));

    for (String token : tokens) {
      if (!token.isEmpty()) {
        vector.merge(normalizeToken(token), 1, Integer::sum);
      }
    }
    return vector;
  }

  /** Character trigram frequency; strings shorter than 3 chars fall back to char frequency. */
  private Map<CharSequence, Integer> vectorizeNgramSet(String value) {
    if (value.length() < 3) {
      Map<CharSequence, Integer> vector = new HashMap<>();
      for (char c : value.toCharArray()) {
        vector.merge(String.valueOf(c), 1, Integer::sum);
      }
      return vector;
    }

    Map<CharSequence, Integer> vector = new HashMap<>((int) ((value.length() - 2) * 1.4));
    for (int i = 0; i <= value.length() - 3; i++) {
      vector.merge(value.substring(i, i + 3), 1, Integer::sum);
    }
    return vector;
  }

  /** Splits on the configured delimiters; "non-alphanumeric" splits on any non-word character. */
  private String[] tokenize(String value) {
    if ("non-alphanumeric".equals(delims)) {
      return value.split("[^a-zA-Z0-9_]+");
    }
    return value.split("[" + Pattern.quote(delims) + "]+");
  }
}
@@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import com.google.common.collect.ImmutableList; +import java.util.List; +import lombok.AllArgsConstructor; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.Setter; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.expression.UnresolvedExpression; + +/** AST node for the PPL cluster command. */ +@Getter +@Setter +@ToString +@EqualsAndHashCode(callSuper = false) +@RequiredArgsConstructor +@AllArgsConstructor +public class Cluster extends UnresolvedPlan { + + private final UnresolvedExpression sourceField; + private final double threshold; + private final String matchMode; + private final String labelField; + private final String countField; + private final boolean labelOnly; + private final boolean showCount; + private final String delims; + private UnresolvedPlan child; + + @Override + public Cluster attach(UnresolvedPlan child) { + this.child = child; + return this; + } + + @Override + public List getChild() { + return this.child == null ? 
ImmutableList.of() : ImmutableList.of(this.child); + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitCluster(this, context); + } +} diff --git a/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java b/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java index ed68dfbcb1b..bb57523aa91 100644 --- a/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/calcite/CalciteRelNodeVisitor.java @@ -2478,6 +2478,98 @@ public RelNode visitKmeans(Kmeans node, CalcitePlanContext context) { throw new CalciteUnsupportedException("Kmeans command is unsupported in Calcite"); } + @Override + public RelNode visitCluster( + org.opensearch.sql.ast.tree.Cluster node, CalcitePlanContext context) { + visitChildren(node, context); + + // Filter out rows where the source field is null before clustering. + RexNode sourceFieldRex = rexVisitor.analyze(node.getSourceField(), context); + context.relBuilder.filter( + context.rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, sourceFieldRex)); + + // Resolve clustering as a window function over all rows (unbounded frame). + // The window function buffers all rows, runs the greedy clustering algorithm, + // and returns an array of cluster labels (one per input row, in order). 
+ List funcParams = new ArrayList<>(); + funcParams.add(node.getSourceField()); + funcParams.add(AstDSL.doubleLiteral(node.getThreshold())); + funcParams.add(AstDSL.stringLiteral(node.getMatchMode())); + funcParams.add(AstDSL.stringLiteral(node.getDelims())); + + RexNode clusterWindow = + rexVisitor.analyze( + new WindowFunction( + new Function( + BuiltinFunctionName.INTERNAL_CLUSTER_LABEL.getName().getFunctionName(), + funcParams), + List.of(), + List.of()), + context); + String arrayAlias = "_cluster_labels_array"; + context.relBuilder.projectPlus(context.relBuilder.alias(clusterWindow, arrayAlias)); + + // Add ROW_NUMBER to index into the array (1-based). + String rowNumAlias = "_cluster_row_idx"; + RexNode rowNum = + context + .relBuilder + .aggregateCall(SqlStdOperatorTable.ROW_NUMBER) + .over() + .rowsBetween(RexWindowBounds.UNBOUNDED_PRECEDING, RexWindowBounds.CURRENT_ROW) + .as(rowNumAlias); + context.relBuilder.projectPlus(rowNum); + + // Extract the label for this row: array[row_number] (ITEM access is 1-based). + RexNode rowIdxAsInt = + context.rexBuilder.makeCast( + context.rexBuilder.getTypeFactory().createSqlType(SqlTypeName.INTEGER), + context.relBuilder.field(rowNumAlias)); + RexNode labelExpr = + context.rexBuilder.makeCall( + SqlStdOperatorTable.ITEM, context.relBuilder.field(arrayAlias), rowIdxAsInt); + context.relBuilder.projectPlus(context.relBuilder.alias(labelExpr, node.getLabelField())); + + // Remove the temporary array and row index columns. 
+ context.relBuilder.projectExcept( + context.relBuilder.field(arrayAlias), context.relBuilder.field(rowNumAlias)); + + if (node.isShowCount()) { + // cluster_count = COUNT(*) OVER (PARTITION BY cluster_label) + RexNode countWindow = + context + .relBuilder + .aggregateCall(SqlStdOperatorTable.COUNT) + .over() + .partitionBy(context.relBuilder.field(node.getLabelField())) + .rowsBetween(RexWindowBounds.UNBOUNDED_PRECEDING, RexWindowBounds.UNBOUNDED_FOLLOWING) + .as(node.getCountField()); + context.relBuilder.projectPlus(countWindow); + } + + if (!node.isLabelOnly()) { + // Filter to representative rows only: keep the first event per cluster. + String convergenceRowNum = "_cluster_convergence_row_num"; + RexNode convergenceRn = + context + .relBuilder + .aggregateCall(SqlStdOperatorTable.ROW_NUMBER) + .over() + .partitionBy(context.relBuilder.field(node.getLabelField())) + .rowsTo(RexWindowBounds.CURRENT_ROW) + .as(convergenceRowNum); + context.relBuilder.projectPlus(convergenceRn); + context.relBuilder.filter( + context.rexBuilder.makeCall( + SqlStdOperatorTable.EQUALS, + context.relBuilder.field(convergenceRowNum), + context.rexBuilder.makeExactLiteral(java.math.BigDecimal.ONE))); + context.relBuilder.projectExcept(context.relBuilder.field(convergenceRowNum)); + } + + return context.relBuilder.peek(); + } + @Override public RelNode visitRareTopN(RareTopN node, CalcitePlanContext context) { visitChildren(node, context); diff --git a/core/src/main/java/org/opensearch/sql/calcite/udf/udaf/ClusterLabelAggFunction.java b/core/src/main/java/org/opensearch/sql/calcite/udf/udaf/ClusterLabelAggFunction.java new file mode 100644 index 00000000000..ae88879893c --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/calcite/udf/udaf/ClusterLabelAggFunction.java @@ -0,0 +1,221 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.calcite.udf.udaf; + +import java.util.ArrayList; +import java.util.List; +import 
org.opensearch.sql.calcite.udf.UserDefinedAggFunction;
import org.opensearch.sql.common.antlr.SyntaxCheckException;
import org.opensearch.sql.common.cluster.TextSimilarityClustering;

/**
 * Aggregate function for the cluster command. Uses buffered processing similar to
 * LogPatternAggFunction to handle large datasets efficiently. Processes events in configurable
 * batches to avoid memory issues.
 *
 * <p>When used as a window function over an unbounded frame, the result is a {@code List} of
 * integer labels where each element corresponds to the cluster label for that row position.
 */
public class ClusterLabelAggFunction
    implements UserDefinedAggFunction<ClusterLabelAggFunction.Acc> {

  private int bufferLimit = 50000; // Configurable buffer size
  private int maxClusters = 10000; // Limit cluster count to prevent memory explosion
  private double threshold = 0.8;
  private String matchMode = "termlist";
  private String delims = " ";

  @Override
  public Acc init() {
    return new Acc();
  }

  @Override
  public Object result(Acc acc) {
    return acc.value(threshold, matchMode, delims, maxClusters);
  }

  /**
   * Generic entry point invoked by the Calcite runtime. The planner passes up to four arguments
   * (field, threshold, matchMode, delims) — see CalciteRelNodeVisitor#visitCluster — so all four
   * arities are dispatched here; shorter forms fall back to the current defaults.
   */
  @Override
  public Acc add(Acc acc, Object... values) {
    switch (values.length) {
      case 1:
        return add(acc, asString(values[0]));
      case 2:
        return add(acc, asString(values[0]), asDouble(values[1]));
      case 3:
        return add(acc, asString(values[0]), asDouble(values[1]), asString(values[2]));
      case 4:
        return add(
            acc,
            asString(values[0]),
            asDouble(values[1]),
            asString(values[2]),
            asString(values[3]));
      default:
        throw new SyntaxCheckException(
            "Unsupported function signature for cluster aggregate. Valid parameters include"
                + " (field: required string), (t: optional double threshold 0.0-1.0, default 0.8),"
                + " (match: optional string algorithm 'termlist'|'termset'|'ngramset', default"
                + " 'termlist'), (delims: optional string delimiters, default ' '), (labelfield:"
                + " optional string output field name, default 'cluster_label'), (countfield:"
                + " optional string count field name, default 'cluster_count')");
    }
  }

  /** Null-safe string coercion; null stays null so the field path can normalize it to "". */
  private static String asString(Object value) {
    return (value != null) ? value.toString() : null;
  }

  /** Numeric coercion for the threshold argument. */
  private static double asDouble(Object value) {
    if (value instanceof Number) {
      return ((Number) value).doubleValue();
    }
    throw new SyntaxCheckException(
        "Cluster aggregate threshold must be numeric, got: " + value);
  }

  public Acc add(
      Acc acc,
      String field,
      double threshold,
      String matchMode,
      String delims,
      int bufferLimit,
      int maxClusters) {
    // Process all rows, even when field is null - convert null to empty string.
    // This ensures the result array matches input row count.
    String processedField = (field != null) ? field : "";

    this.threshold = threshold;
    this.matchMode = matchMode;
    this.delims = delims;
    this.bufferLimit = bufferLimit;
    this.maxClusters = maxClusters;

    acc.evaluate(processedField);

    // Flush in batches to bound memory; >= (not ==) so a limit lowered mid-stream still flushes.
    if (bufferLimit > 0 && acc.bufferSize() >= bufferLimit) {
      acc.partialMerge(threshold, matchMode, delims, maxClusters);
      acc.clearBuffer();
    }

    return acc;
  }

  public Acc add(Acc acc, String field, double threshold, String matchMode, String delims) {
    return add(acc, field, threshold, matchMode, delims, this.bufferLimit, this.maxClusters);
  }

  public Acc add(Acc acc, String field, double threshold, String matchMode) {
    return add(acc, field, threshold, matchMode, this.delims, this.bufferLimit, this.maxClusters);
  }

  public Acc add(Acc acc, String field, double threshold) {
    return add(
        acc, field, threshold, this.matchMode, this.delims, this.bufferLimit, this.maxClusters);
  }

  public Acc add(Acc acc, String field) {
    return add(
        acc,
        field,
        this.threshold,
        this.matchMode,
        this.delims,
        this.bufferLimit,
        this.maxClusters);
  }

  /** Accumulator: buffers raw values, then clusters them greedily against global representatives. */
  public static class Acc implements Accumulator {
    private final List<String> buffer = new ArrayList<>();
    private final List<ClusterRepresentative> globalClusters = new ArrayList<>();
    private final List<Integer> allLabels = new ArrayList<>();
    private int nextClusterId = 1;

    public int bufferSize() {
      return buffer.size();
    }

    public void evaluate(String value) {
      buffer.add(value != null ? value : "");
    }

    /**
     * Clusters all buffered values and appends their labels to {@code allLabels}. Expected args:
     * (threshold: Double, matchMode: String, delims: String, maxClusters: Integer).
     */
    public void partialMerge(Object... argList) {
      if (buffer.isEmpty()) {
        return;
      }

      double threshold = (Double) argList[0];
      String matchMode = (String) argList[1];
      String delims = (String) argList[2];
      int maxClusters = (Integer) argList[3];

      TextSimilarityClustering clustering =
          new TextSimilarityClustering(threshold, matchMode, delims);

      for (String value : buffer) {
        ClusterAssignment assignment =
            findOrCreateCluster(value, clustering, threshold, maxClusters);
        allLabels.add(assignment.clusterId);
      }
    }

    /** Finds the best-matching cluster for {@code value}, creating one if none meets threshold. */
    private ClusterAssignment findOrCreateCluster(
        String value, TextSimilarityClustering clustering, double threshold, int maxClusters) {
      double bestSimilarity = -1.0;
      ClusterRepresentative bestCluster = null;

      // Compare against existing global clusters.
      for (ClusterRepresentative cluster : globalClusters) {
        double similarity = clustering.computeSimilarity(value, cluster.representative);
        if (similarity > bestSimilarity) {
          bestSimilarity = similarity;
          bestCluster = cluster;
        }
      }

      // Epsilon tolerance absorbs floating-point error in the cosine computation.
      if (bestSimilarity >= threshold - 1e-9 && bestCluster != null) {
        // Join existing cluster.
        bestCluster.size++;
        return new ClusterAssignment(bestCluster.id, bestCluster.size);
      } else if (globalClusters.size() < maxClusters) {
        // Create new cluster.
        ClusterRepresentative newCluster = new ClusterRepresentative(nextClusterId++, value, 1);
        globalClusters.add(newCluster);
        return new ClusterAssignment(newCluster.id, 1);
      } else {
        // At the max-cluster limit: force into the closest existing cluster.
        if (bestCluster != null) {
          bestCluster.size++;
          return new ClusterAssignment(bestCluster.id, bestCluster.size);
        } else if (!globalClusters.isEmpty()) {
          // Fallback: assign to cluster 1.
          globalClusters.get(0).size++;
          return new ClusterAssignment(globalClusters.get(0).id, globalClusters.get(0).size);
        } else {
          // Emergency fallback: create first cluster (only reachable when maxClusters <= 0).
          ClusterRepresentative newCluster = new ClusterRepresentative(1, value, 1);
          globalClusters.add(newCluster);
          nextClusterId = 2;
          return new ClusterAssignment(1, 1);
        }
      }
    }

    public void clearBuffer() {
      buffer.clear();
    }

    @Override
    public Object value(Object... argList) {
      partialMerge(argList);
      clearBuffer();
      // Defensive copy so callers cannot mutate the accumulator's label list.
      return new ArrayList<>(allLabels);
    }

    /** Represents a cluster with its representative text and current size. */
    private static class ClusterRepresentative {
      final int id;
      final String representative;
      int size;

      ClusterRepresentative(int id, String representative, int size) {
        this.id = id;
        this.representative = representative;
        this.size = size;
      }
    }

    /** Result of cluster assignment. */
    private static class ClusterAssignment {
      final int clusterId;
      final int clusterSize;

      ClusterAssignment(int clusterId, int clusterSize) {
        this.clusterId = clusterId;
        this.clusterSize = clusterSize;
      }
    }
  }
}
public static Optional of(String str) { diff --git a/core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java b/core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java index 2aebf7efe34..20c94641865 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/PPLBuiltinOperators.java @@ -26,9 +26,12 @@ import org.apache.calcite.sql.SqlKind; import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.type.ReturnTypes; +import org.apache.calcite.sql.type.SqlTypeName; import org.apache.calcite.sql.type.SqlTypeTransforms; +import org.apache.calcite.sql.type.SqlTypeUtil; import org.apache.calcite.sql.util.ReflectiveSqlOperatorTable; import org.apache.calcite.util.BuiltInMethod; +import org.opensearch.sql.calcite.udf.udaf.ClusterLabelAggFunction; import org.opensearch.sql.calcite.udf.udaf.FirstAggFunction; import org.opensearch.sql.calcite.udf.udaf.LastAggFunction; import org.opensearch.sql.calcite.udf.udaf.ListAggFunction; @@ -483,6 +486,16 @@ public class PPLBuiltinOperators extends ReflectiveSqlOperatorTable { "pattern", ReturnTypes.explicit(UserDefinedFunctionUtils.nullablePatternAggList), null); + public static final SqlAggFunction CLUSTER_LABEL = + createUserDefinedAggFunction( + ClusterLabelAggFunction.class, + "cluster_label", + opBinding -> + SqlTypeUtil.createArrayType( + opBinding.getTypeFactory(), + opBinding.getTypeFactory().createSqlType(SqlTypeName.INTEGER), + true), + null); public static final SqlAggFunction LIST = createUserDefinedAggFunction( ListAggFunction.class, "LIST", PPLReturnTypes.STRING_ARRAY, PPLOperandTypes.ANY_SCALAR); diff --git a/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java b/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java index 30d7c055470..48b8176466c 100644 --- 
a/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/PPLFuncImpTable.java @@ -82,6 +82,7 @@ import static org.opensearch.sql.expression.function.BuiltinFunctionName.IF; import static org.opensearch.sql.expression.function.BuiltinFunctionName.IFNULL; import static org.opensearch.sql.expression.function.BuiltinFunctionName.ILIKE; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.INTERNAL_CLUSTER_LABEL; import static org.opensearch.sql.expression.function.BuiltinFunctionName.INTERNAL_GROK; import static org.opensearch.sql.expression.function.BuiltinFunctionName.INTERNAL_ITEM; import static org.opensearch.sql.expression.function.BuiltinFunctionName.INTERNAL_PARSE; @@ -1349,6 +1350,7 @@ void populate() { registerOperator(STDDEV_POP, PPLBuiltinOperators.STDDEV_POP_NULLABLE); registerOperator(TAKE, PPLBuiltinOperators.TAKE); registerOperator(INTERNAL_PATTERN, PPLBuiltinOperators.INTERNAL_PATTERN); + registerOperator(INTERNAL_CLUSTER_LABEL, PPLBuiltinOperators.CLUSTER_LABEL); registerOperator(LIST, PPLBuiltinOperators.LIST); registerOperator(VALUES, PPLBuiltinOperators.VALUES); diff --git a/docs/category.json b/docs/category.json index 5e9b6f954a5..5c8894923ed 100644 --- a/docs/category.json +++ b/docs/category.json @@ -44,6 +44,7 @@ "user/ppl/cmd/subquery.md", "user/ppl/cmd/syntax.md", "user/ppl/cmd/chart.md", + "user/ppl/cmd/cluster.md", "user/ppl/cmd/timechart.md", "user/ppl/cmd/top.md", "user/ppl/cmd/trendline.md", diff --git a/docs/user/ppl/cmd/cluster.md b/docs/user/ppl/cmd/cluster.md new file mode 100644 index 00000000000..dc0fca30a91 --- /dev/null +++ b/docs/user/ppl/cmd/cluster.md @@ -0,0 +1,211 @@ +# cluster + +The `cluster` command groups documents into clusters based on text similarity using various clustering algorithms. Documents with similar text content are assigned to the same cluster and receive matching `cluster_label` values. 
Rows where the source field is null are excluded from the results. + +## Syntax + +The `cluster` command has the following syntax: + +```syntax +cluster <field> [t=<threshold>] [match=<algorithm>] [labelfield=<field-name>] [countfield=<field-name>] [showcount=<boolean>] [labelonly=<boolean>] [delims=<delimiters>] +``` + +## Parameters + +The `cluster` command supports the following parameters. + +| Parameter | Required/Optional | Description | +| --- | --- | --- | +| `<field>` | Required | The text field to use for clustering analysis. | +| `t` | Optional | Similarity threshold, greater than 0.0 and less than 1.0 (exclusive). Documents with similarity above this threshold are grouped together. Default is `0.8`. | +| `match` | Optional | Clustering algorithm to use. Valid values are `termlist`, `termset`, `ngramset`. Default is `termlist`. | +| `labelfield` | Optional | Name of the field to store the cluster label. Default is `cluster_label`. | +| `countfield` | Optional | Name of the field to store the cluster size. Default is `cluster_count`. | +| `showcount` | Optional | Whether to include the cluster count field in the output. Default is `false`. | +| `labelonly` | Optional | When `true`, keeps all rows and only adds the cluster label. When `false` (default), deduplicates by keeping only the first representative row per cluster. Default is `false`. | +| `delims` | Optional | Delimiter characters used for tokenization. Default is `non-alphanumeric` (splits on any non-alphanumeric character).
| + + +## Example 1: Basic text clustering + +The following query groups log messages by similarity using default settings: + +```ppl +source=otellogs +| cluster body +| fields body, cluster_label +| head 4 +``` + +The query returns the following results: + +```text +fetched rows / total rows = 4/4 ++--------------------------------------------------------------------------------------------------------------------------------------------------------+---------------+ +| body | cluster_label | +|--------------------------------------------------------------------------------------------------------------------------------------------------------+---------------| +| User e1ce63e6-8501-11f0-930d-c2fcbdc05f14 adding 4 of product HQTGWGPNH4 to cart | 1 | +| Payment failed: Insufficient funds for user@example.com | 2 | +| Query contains Lucene special characters: +field:value -excluded AND (grouped OR terms) NOT "exact phrase" wildcard* fuzzy~2 /regex/ [range TO search] | 3 | +| 192.168.1.1 - - [15/Jan/2024:10:30:03 +0000] "GET /api/products?search=laptop&category=electronics HTTP/1.1" 200 1234 "-" "Mozilla/5.0" | 4 | ++--------------------------------------------------------------------------------------------------------------------------------------------------------+---------------+ +``` + + +## Example 2: Clustering with showcount + +The following query uses the `termset` algorithm with a lower threshold to group more messages together, and includes the cluster count: + +```ppl +source=otellogs +| cluster body match=termset t=0.3 showcount=true +| fields body, cluster_label, cluster_count +| head 5 +``` + +The query returns the following results: + +```text +fetched rows / total rows = 5/5 ++--------------------------------------------------------------------------------------------------------------------------------------------------------+---------------+---------------+ +| body | cluster_label | cluster_count | 
+|--------------------------------------------------------------------------------------------------------------------------------------------------------+---------------+---------------| +| User e1ce63e6-8501-11f0-930d-c2fcbdc05f14 adding 4 of product HQTGWGPNH4 to cart | 1 | 3 | +| Payment failed: Insufficient funds for user@example.com | 2 | 3 | +| Query contains Lucene special characters: +field:value -excluded AND (grouped OR terms) NOT "exact phrase" wildcard* fuzzy~2 /regex/ [range TO search] | 3 | 1 | +| Email notification sent to john.doe+newsletter@company.com with subject: 'Welcome! Your order #12345 is confirmed' | 4 | 1 | +| Database connection pool exhausted: postgresql://db.example.com:5432/production | 5 | 1 | ++--------------------------------------------------------------------------------------------------------------------------------------------------------+---------------+---------------+ +``` + + +## Example 3: Custom similarity threshold + +The following query uses a higher similarity threshold to create more distinct clusters: + +```ppl +source=otellogs +| cluster body t=0.9 showcount=true +| fields body, cluster_label, cluster_count +| head 5 +``` + +The query returns the following results: + +```text +fetched rows / total rows = 5/5 ++--------------------------------------------------------------------------------------------------------------------------------------------------------+---------------+---------------+ +| body | cluster_label | cluster_count | +|--------------------------------------------------------------------------------------------------------------------------------------------------------+---------------+---------------| +| User e1ce63e6-8501-11f0-930d-c2fcbdc05f14 adding 4 of product HQTGWGPNH4 to cart | 1 | 1 | +| Payment failed: Insufficient funds for user@example.com | 2 | 1 | +| Query contains Lucene special characters: +field:value -excluded AND (grouped OR terms) NOT "exact phrase" wildcard* fuzzy~2 /regex/ 
[range TO search] | 3 | 1 | +| 192.168.1.1 - - [15/Jan/2024:10:30:03 +0000] "GET /api/products?search=laptop&category=electronics HTTP/1.1" 200 1234 "-" "Mozilla/5.0" | 4 | 1 | +| Email notification sent to john.doe+newsletter@company.com with subject: 'Welcome! Your order #12345 is confirmed' | 5 | 1 | ++--------------------------------------------------------------------------------------------------------------------------------------------------------+---------------+---------------+ +``` + + +## Example 4: Clustering with termset algorithm + +The following query uses the `termset` algorithm which ignores word order when comparing text: + +```ppl +source=otellogs +| cluster body match=termset showcount=true +| fields body, cluster_label, cluster_count +| head 4 +``` + +The query returns the following results: + +```text +fetched rows / total rows = 4/4 ++--------------------------------------------------------------------------------------------------------------------------------------------------------+---------------+---------------+ +| body | cluster_label | cluster_count | +|--------------------------------------------------------------------------------------------------------------------------------------------------------+---------------+---------------| +| User e1ce63e6-8501-11f0-930d-c2fcbdc05f14 adding 4 of product HQTGWGPNH4 to cart | 1 | 1 | +| Payment failed: Insufficient funds for user@example.com | 2 | 1 | +| Query contains Lucene special characters: +field:value -excluded AND (grouped OR terms) NOT "exact phrase" wildcard* fuzzy~2 /regex/ [range TO search] | 3 | 1 | +| 192.168.1.1 - - [15/Jan/2024:10:30:03 +0000] "GET /api/products?search=laptop&category=electronics HTTP/1.1" 200 1234 "-" "Mozilla/5.0" | 4 | 2 | ++--------------------------------------------------------------------------------------------------------------------------------------------------------+---------------+---------------+ +``` + + +## Example 5: Clustering with ngramset 
algorithm + +The following query uses the `ngramset` algorithm which compares character trigrams for fuzzy matching: + +```ppl +source=otellogs +| cluster body match=ngramset showcount=true +| fields body, cluster_label, cluster_count +| head 4 +``` + +The query returns the following results: + +```text +fetched rows / total rows = 4/4 ++--------------------------------------------------------------------------------------------------------------------------------------------------------+---------------+---------------+ +| body | cluster_label | cluster_count | +|--------------------------------------------------------------------------------------------------------------------------------------------------------+---------------+---------------| +| User e1ce63e6-8501-11f0-930d-c2fcbdc05f14 adding 4 of product HQTGWGPNH4 to cart | 1 | 1 | +| Payment failed: Insufficient funds for user@example.com | 2 | 1 | +| Query contains Lucene special characters: +field:value -excluded AND (grouped OR terms) NOT "exact phrase" wildcard* fuzzy~2 /regex/ [range TO search] | 3 | 1 | +| 192.168.1.1 - - [15/Jan/2024:10:30:03 +0000] "GET /api/products?search=laptop&category=electronics HTTP/1.1" 200 1234 "-" "Mozilla/5.0" | 4 | 1 | ++--------------------------------------------------------------------------------------------------------------------------------------------------------+---------------+---------------+ +``` + + +## Example 6: Custom field names + +The following query uses custom field names for the cluster label and count: + +```ppl +source=otellogs +| cluster body match=termset t=0.3 labelfield=log_group countfield=group_size showcount=true +| fields body, log_group, group_size +| head 4 +``` + +The query returns the following results: + +```text +fetched rows / total rows = 4/4 ++--------------------------------------------------------------------------------------------------------------------------------------------------------+-----------+------------+ +| body | 
log_group | group_size | +|--------------------------------------------------------------------------------------------------------------------------------------------------------+-----------+------------| +| User e1ce63e6-8501-11f0-930d-c2fcbdc05f14 adding 4 of product HQTGWGPNH4 to cart | 1 | 3 | +| Payment failed: Insufficient funds for user@example.com | 2 | 3 | +| Query contains Lucene special characters: +field:value -excluded AND (grouped OR terms) NOT "exact phrase" wildcard* fuzzy~2 /regex/ [range TO search] | 3 | 1 | +| Email notification sent to john.doe+newsletter@company.com with subject: 'Welcome! Your order #12345 is confirmed' | 4 | 1 | ++--------------------------------------------------------------------------------------------------------------------------------------------------------+-----------+------------+ +``` + + +## Example 7: Label only mode + +The following query adds cluster labels to all rows without deduplicating. By default (`labelonly=false`), only the first representative row per cluster is kept: + +```ppl +source=otellogs +| cluster body match=termset t=0.3 labelonly=true showcount=true +| fields body, cluster_label, cluster_count +| head 5 +``` + +The query returns the following results: + +```text +fetched rows / total rows = 5/5 ++-----------------------------------------------------------------------------------------------------------------------------------------+---------------+---------------+ +| body | cluster_label | cluster_count | +|-----------------------------------------------------------------------------------------------------------------------------------------+---------------+---------------| +| User e1ce63e6-8501-11f0-930d-c2fcbdc05f14 adding 4 of product HQTGWGPNH4 to cart | 1 | 3 | +| 192.168.1.1 - - [15/Jan/2024:10:30:03 +0000] "GET /api/products?search=laptop&category=electronics HTTP/1.1" 200 1234 "-" "Mozilla/5.0" | 1 | 3 | +| [2024-01-15 10:30:09] production.INFO: User authentication successful for 
admin@company.org using OAuth2 | 1 | 3 | +| Payment failed: Insufficient funds for user@example.com | 2 | 3 | +| Elasticsearch query failed: {"query":{"bool":{"must":[{"match":{"email":"*@example.com"}}]}}} | 2 | 3 | ++-----------------------------------------------------------------------------------------------------------------------------------------+---------------+---------------+ +``` diff --git a/docs/user/ppl/index.md b/docs/user/ppl/index.md index 27f59fa4b95..f71b5a32809 100644 --- a/docs/user/ppl/index.md +++ b/docs/user/ppl/index.md @@ -55,6 +55,7 @@ source=accounts | [bin command](cmd/bin.md) | 3.3 | experimental (since 3.3) | Group numeric values into buckets of equal intervals. | | [timechart command](cmd/timechart.md) | 3.3 | experimental (since 3.3) | Create time-based charts and visualizations. | | [chart command](cmd/chart.md) | 3.4 | experimental (since 3.4) | Apply statistical aggregations to search results and group the data for visualizations. | +| [cluster command](cmd/cluster.md) | 3.7 | experimental (since 3.7) | Group documents into clusters based on text similarity using various clustering algorithms. | | [trendline command](cmd/trendline.md) | 3.0 | experimental (since 3.0) | Calculate moving averages of fields. | | [sort command](cmd/sort.md) | 1.0 | stable (since 1.0) | Sort all the search results by the specified fields. | | [reverse command](cmd/reverse.md) | 3.2 | experimental (since 3.2) | Reverse the display order of search results. 
| diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/CalciteNoPushdownIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/CalciteNoPushdownIT.java index 014091ec072..c38a0515b34 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/CalciteNoPushdownIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/CalciteNoPushdownIT.java @@ -26,6 +26,7 @@ CalciteConvertCommandIT.class, CalciteArrayFunctionIT.class, CalciteBinCommandIT.class, + CalciteClusterCommandIT.class, CalciteConvertTZFunctionIT.class, CalciteCsvFormatIT.class, CalciteDataTypeIT.class, diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteClusterCommandIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteClusterCommandIT.java new file mode 100644 index 00000000000..b506361f5cc --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteClusterCommandIT.java @@ -0,0 +1,215 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.calcite.remote; + +import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_BANK; +import static org.opensearch.sql.util.MatcherUtils.rows; +import static org.opensearch.sql.util.MatcherUtils.schema; +import static org.opensearch.sql.util.MatcherUtils.verifyDataRows; +import static org.opensearch.sql.util.MatcherUtils.verifySchema; + +import java.io.IOException; +import org.json.JSONObject; +import org.junit.jupiter.api.Test; +import org.opensearch.sql.ppl.PPLIntegTestCase; + +public class CalciteClusterCommandIT extends PPLIntegTestCase { + + @Override + public void init() throws Exception { + super.init(); + enableCalcite(); + loadIndex(Index.BANK); + } + + @Test + public void testBasicCluster() throws IOException { + JSONObject result = + executeQuery( + String.format( + "search source=%s | eval message = 'user login failed' | cluster message | fields" + + " cluster_label | head 1", + 
TEST_INDEX_BANK)); + verifySchema(result, schema("cluster_label", null, "int")); + verifyDataRows(result, rows(1)); + } + + @Test + public void testClusterWithCustomThreshold() throws IOException { + JSONObject result = + executeQuery( + String.format( + "search source=%s | eval message = 'error connecting to database' | cluster message" + + " t=0.8 | fields cluster_label | head 1", + TEST_INDEX_BANK)); + verifySchema(result, schema("cluster_label", null, "int")); + verifyDataRows(result, rows(1)); + } + + @Test + public void testClusterWithTermsetMatch() throws IOException { + JSONObject result = + executeQuery( + String.format( + "search source=%s | eval message = 'user authentication failed' | cluster message" + + " match=termset | fields cluster_label | head 1", + TEST_INDEX_BANK)); + verifySchema(result, schema("cluster_label", null, "int")); + verifyDataRows(result, rows(1)); + } + + @Test + public void testClusterWithNgramsetMatch() throws IOException { + JSONObject result = + executeQuery( + String.format( + "search source=%s | eval message = 'connection timeout error' | cluster message" + + " match=ngramset | fields cluster_label | head 1", + TEST_INDEX_BANK)); + verifySchema(result, schema("cluster_label", null, "int")); + verifyDataRows(result, rows(1)); + } + + @Test + public void testClusterWithCustomLabelField() throws IOException { + JSONObject result = + executeQuery( + String.format( + "search source=%s | eval message = 'database error occurred' | cluster message" + + " labelfield=my_cluster | fields my_cluster | head 1", + TEST_INDEX_BANK)); + verifySchema(result, schema("my_cluster", null, "int")); + verifyDataRows(result, rows(1)); + } + + @Test + public void testClusterWithShowCount() throws IOException { + JSONObject result = + executeQuery( + String.format( + "search source=%s | eval message = 'server unavailable' | cluster message" + + " showcount=true | fields cluster_label, cluster_count | head 1", + TEST_INDEX_BANK)); + verifySchema( + 
result, schema("cluster_label", null, "int"), schema("cluster_count", null, "bigint")); + verifyDataRows(result, rows(1, 7)); + } + + @Test + public void testClusterWithCustomCountField() throws IOException { + JSONObject result = + executeQuery( + String.format( + "search source=%s | eval message = 'server unavailable' | cluster message" + + " countfield=my_count showcount=true | fields cluster_label, my_count" + + " | head 1", + TEST_INDEX_BANK)); + verifySchema(result, schema("cluster_label", null, "int"), schema("my_count", null, "bigint")); + verifyDataRows(result, rows(1, 7)); + } + + @Test + public void testClusterWithDelimiters() throws IOException { + JSONObject result = + executeQuery( + String.format( + "search source=%s | eval message = 'user-login-failed' | cluster message delims='-'" + + " | fields cluster_label | head 1", + TEST_INDEX_BANK)); + verifySchema(result, schema("cluster_label", null, "int")); + verifyDataRows(result, rows(1)); + } + + @Test + public void testClusterWithAllParameters() throws IOException { + JSONObject result = + executeQuery( + String.format( + "search source=%s | eval message = 'system error detected' | cluster message t=0.7" + + " match=termset labelfield=custom_label countfield=custom_count" + + " showcount=true | fields custom_label, custom_count | head 1", + TEST_INDEX_BANK)); + verifySchema( + result, schema("custom_label", null, "int"), schema("custom_count", null, "bigint")); + verifyDataRows(result, rows(1, 7)); + } + + @Test + public void testClusterPreservesOtherFields() throws IOException { + JSONObject result = + executeQuery( + String.format( + "search source=%s | eval message = 'system alert' | cluster message | fields" + + " account_number, message, cluster_label | head 1", + TEST_INDEX_BANK)); + verifySchema( + result, + schema("account_number", null, "bigint"), + schema("message", null, "string"), + schema("cluster_label", null, "int")); + verifyDataRows(result, rows(1, "system alert", 1)); + } + + @Test 
+ public void testClusterGroupsSimilarMessages() throws IOException { + JSONObject result = + executeQuery( + String.format( + "search source=%s | eval message = case(account_number=1, 'login failed for user" + + " admin', account_number=6, 'login failed for user root' else 'login" + + " failed for user guest') | cluster message match=termset showcount=true" + + " | fields message, cluster_label, cluster_count", + TEST_INDEX_BANK)); + verifySchema( + result, + schema("message", null, "string"), + schema("cluster_label", null, "int"), + schema("cluster_count", null, "bigint")); + // All similar messages should dedup to one representative row + verifyDataRows(result, rows("login failed for user admin", 1, 7)); + } + + @Test + public void testClusterDedupsByDefault() throws IOException { + JSONObject result = + executeQuery( + String.format( + "search source=%s | eval message = case(account_number=1, 'login failed for user" + + " admin', account_number=6, 'login failed for user root' else 'login" + + " failed for user guest') | cluster message match=termset showcount=true" + + " | fields message, cluster_label, cluster_count", + TEST_INDEX_BANK)); + verifySchema( + result, + schema("message", null, "string"), + schema("cluster_label", null, "int"), + schema("cluster_count", null, "bigint")); + verifyDataRows(result, rows("login failed for user admin", 1, 7)); + } + + @Test + public void testClusterLabelOnlyKeepsAllRows() throws IOException { + JSONObject result = + executeQuery( + String.format( + "search source=%s | eval message = case(account_number=1, 'login failed for user" + + " admin', account_number=6, 'login failed for user root' else 'login" + + " failed for user guest') | cluster message match=termset labelonly=true" + + " showcount=true | fields message, cluster_label, cluster_count | head 3", + TEST_INDEX_BANK)); + verifySchema( + result, + schema("message", null, "string"), + schema("cluster_label", null, "int"), + schema("cluster_count", null, 
"bigint")); + verifyDataRows( + result, + rows("login failed for user admin", 1, 7), + rows("login failed for user root", 1, 7), + rows("login failed for user guest", 1, 7)); + } + +} diff --git a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java index 73531a8895c..bb17020fc2c 100644 --- a/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/calcite/remote/CalciteExplainIT.java @@ -606,6 +606,17 @@ public void testExplainBinWithStartEnd() throws IOException { "source=opensearch-sql_test_index_account | bin balance start=0 end=100001 | head 5")); } + @Test + public void testExplainClusterCommand() throws IOException { + String query = + "source=opensearch-sql_test_index_account | eval message='login error' | cluster message |" + + " head 5"; + var result = explainQueryYaml(query); + + String expected = loadExpectedPlan("explain_cluster.yaml"); + assertYamlEqualsIgnoreId(expected, result); + } + @Test public void testExplainBinWithAligntime() throws IOException { String expected = loadExpectedPlan("explain_bin_aligntime.yaml"); diff --git a/integ-test/src/test/java/org/opensearch/sql/security/CrossClusterSearchIT.java b/integ-test/src/test/java/org/opensearch/sql/security/CrossClusterSearchIT.java index 0029921c1fc..2f574549169 100644 --- a/integ-test/src/test/java/org/opensearch/sql/security/CrossClusterSearchIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/security/CrossClusterSearchIT.java @@ -6,6 +6,7 @@ package org.opensearch.sql.security; import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_BANK; +import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_OTEL_LOGS; import static org.opensearch.sql.util.MatcherUtils.columnName; import static org.opensearch.sql.util.MatcherUtils.rows; import static 
org.opensearch.sql.util.MatcherUtils.verifyColumn; @@ -27,6 +28,7 @@ protected void init() throws Exception { loadIndex(Index.DOG); loadIndex(Index.DOG, remoteClient()); loadIndex(Index.ACCOUNT); + loadIndex(Index.OTELLOGS); } @Test @@ -237,4 +239,49 @@ public void testCrossClusterConvertWithAlias() throws IOException { disableCalcite(); } + + @Test + public void testCrossClusterClusterCommand() throws IOException { + enableCalcite(); + + JSONObject result = + executeQuery( + String.format( + "search source=%s | cluster body | fields cluster_label", TEST_INDEX_OTEL_LOGS)); + verifyColumn(result, columnName("cluster_label")); + + disableCalcite(); + } + + @Test + public void testCrossClusterClusterCommandWithParameters() throws IOException { + enableCalcite(); + + JSONObject result = + executeQuery( + String.format( + "search source=%s | cluster body t=0.8 match=termset showcount=true" + + " | fields cluster_label, cluster_count, body", + TEST_INDEX_OTEL_LOGS)); + verifyColumn( + result, columnName("cluster_label"), columnName("cluster_count"), columnName("body")); + + disableCalcite(); + } + + @Test + public void testCrossClusterClusterCommandMultiCluster() throws IOException { + enableCalcite(); + + JSONObject result = + executeQuery( + String.format( + "search source=%s | cluster body showcount=true | fields" + + " cluster_label, cluster_count, body", + TEST_INDEX_OTEL_LOGS)); + verifyColumn( + result, columnName("cluster_label"), columnName("cluster_count"), columnName("body")); + + disableCalcite(); + } } diff --git a/integ-test/src/test/resources/expectedOutput/calcite/explain_cluster.yaml b/integ-test/src/test/resources/expectedOutput/calcite/explain_cluster.yaml new file mode 100644 index 00000000000..b99f903d278 --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite/explain_cluster.yaml @@ -0,0 +1,21 @@ +calcite: + logical: | + LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT]) + LogicalProject(account_number=[$0], firstname=[$1], 
address=[$2], balance=[$3], gender=[$4], city=[$5], employer=[$6], state=[$7], age=[$8], email=[$9], lastname=[$10], message=[$17], cluster_label=[$18]) + LogicalSort(fetch=[5]) + LogicalProject(account_number=[$0], firstname=[$1], address=[$2], balance=[$3], gender=[$4], city=[$5], employer=[$6], state=[$7], age=[$8], email=[$9], lastname=[$10], _id=[$11], _index=[$12], _score=[$13], _maxscore=[$14], _sort=[$15], _routing=[$16], message=[$17], cluster_label=[$18]) + LogicalFilter(condition=[=($19, 1)]) + LogicalProject(account_number=[$0], firstname=[$1], address=[$2], balance=[$3], gender=[$4], city=[$5], employer=[$6], state=[$7], age=[$8], email=[$9], lastname=[$10], _id=[$11], _index=[$12], _score=[$13], _maxscore=[$14], _sort=[$15], _routing=[$16], message=[$17], cluster_label=[$18], _cluster_convergence_row_num=[ROW_NUMBER() OVER (PARTITION BY $18)]) + LogicalProject(account_number=[$0], firstname=[$1], address=[$2], balance=[$3], gender=[$4], city=[$5], employer=[$6], state=[$7], age=[$8], email=[$9], lastname=[$10], _id=[$11], _index=[$12], _score=[$13], _maxscore=[$14], _sort=[$15], _routing=[$16], message=[$17], cluster_label=[ITEM($18, CAST(ROW_NUMBER() OVER ()):INTEGER NOT NULL)]) + LogicalProject(account_number=[$0], firstname=[$1], address=[$2], balance=[$3], gender=[$4], city=[$5], employer=[$6], state=[$7], age=[$8], email=[$9], lastname=[$10], _id=[$11], _index=[$12], _score=[$13], _maxscore=[$14], _sort=[$15], _routing=[$16], message=['login error':VARCHAR], _cluster_labels_array=[cluster_label('login error':VARCHAR, 0.8E0:DOUBLE, 'termlist':VARCHAR, 'non-alphanumeric':VARCHAR) OVER ()]) + CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]]) + physical: | + EnumerableLimit(fetch=[10000]) + EnumerableCalc(expr#0..13=[{inputs}], proj#0..12=[{exprs}]) + EnumerableLimit(fetch=[5]) + EnumerableCalc(expr#0..13=[{inputs}], expr#14=[1], expr#15=[=($t13, $t14)], proj#0..13=[{exprs}], $condition=[$t15]) + 
EnumerableWindow(window#0=[window(partition {12} rows between UNBOUNDED PRECEDING and CURRENT ROW aggs [ROW_NUMBER()])]) + EnumerableCalc(expr#0..12=[{inputs}], expr#13=['login error':VARCHAR], expr#14=[CAST($t12):INTEGER NOT NULL], expr#15=[ITEM($t11, $t14)], proj#0..10=[{exprs}], message=[$t13], cluster_label=[$t15]) + EnumerableWindow(window#0=[window(rows between UNBOUNDED PRECEDING and CURRENT ROW aggs [ROW_NUMBER()])]) + EnumerableWindow(window#0=[window(aggs [cluster_label($11, $12, $13, $14)])], constants=[['login error':VARCHAR, 0.8E0:DOUBLE, 'termlist':VARCHAR, 'non-alphanumeric':VARCHAR]]) + CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]], PushDownContext=[[PROJECT->[account_number, firstname, address, balance, gender, city, employer, state, age, email, lastname]], OpenSearchRequestBuilder(sourceBuilder={"from":0,"timeout":"1m","_source":{"includes":["account_number","firstname","address","balance","gender","city","employer","state","age","email","lastname"],"excludes":[]}}, requestedTotalSize=2147483647, pageSize=null, startFrom=0)]) \ No newline at end of file diff --git a/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_cluster.yaml b/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_cluster.yaml new file mode 100644 index 00000000000..92dd7b82f64 --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/calcite_no_pushdown/explain_cluster.yaml @@ -0,0 +1,22 @@ +calcite: + logical: | + LogicalSystemLimit(fetch=[10000], type=[QUERY_SIZE_LIMIT]) + LogicalProject(account_number=[$0], firstname=[$1], address=[$2], balance=[$3], gender=[$4], city=[$5], employer=[$6], state=[$7], age=[$8], email=[$9], lastname=[$10], message=[$17], cluster_label=[$18]) + LogicalSort(fetch=[5]) + LogicalProject(account_number=[$0], firstname=[$1], address=[$2], balance=[$3], gender=[$4], city=[$5], employer=[$6], state=[$7], age=[$8], email=[$9], lastname=[$10], _id=[$11], _index=[$12], 
_score=[$13], _maxscore=[$14], _sort=[$15], _routing=[$16], message=[$17], cluster_label=[$18]) + LogicalFilter(condition=[=($19, 1)]) + LogicalProject(account_number=[$0], firstname=[$1], address=[$2], balance=[$3], gender=[$4], city=[$5], employer=[$6], state=[$7], age=[$8], email=[$9], lastname=[$10], _id=[$11], _index=[$12], _score=[$13], _maxscore=[$14], _sort=[$15], _routing=[$16], message=[$17], cluster_label=[$18], _cluster_convergence_row_num=[ROW_NUMBER() OVER (PARTITION BY $18)]) + LogicalProject(account_number=[$0], firstname=[$1], address=[$2], balance=[$3], gender=[$4], city=[$5], employer=[$6], state=[$7], age=[$8], email=[$9], lastname=[$10], _id=[$11], _index=[$12], _score=[$13], _maxscore=[$14], _sort=[$15], _routing=[$16], message=[$17], cluster_label=[ITEM($18, CAST(ROW_NUMBER() OVER ()):INTEGER NOT NULL)]) + LogicalProject(account_number=[$0], firstname=[$1], address=[$2], balance=[$3], gender=[$4], city=[$5], employer=[$6], state=[$7], age=[$8], email=[$9], lastname=[$10], _id=[$11], _index=[$12], _score=[$13], _maxscore=[$14], _sort=[$15], _routing=[$16], message=['login error':VARCHAR], _cluster_labels_array=[cluster_label('login error':VARCHAR, 0.8E0:DOUBLE, 'termlist':VARCHAR, 'non-alphanumeric':VARCHAR) OVER ()]) + CalciteLogicalIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]]) + physical: | + EnumerableLimit(fetch=[10000]) + EnumerableCalc(expr#0..13=[{inputs}], proj#0..12=[{exprs}]) + EnumerableLimit(fetch=[5]) + EnumerableCalc(expr#0..13=[{inputs}], expr#14=[1], expr#15=[=($t13, $t14)], proj#0..13=[{exprs}], $condition=[$t15]) + EnumerableWindow(window#0=[window(partition {12} rows between UNBOUNDED PRECEDING and CURRENT ROW aggs [ROW_NUMBER()])]) + EnumerableCalc(expr#0..12=[{inputs}], expr#13=['login error':VARCHAR], expr#14=[CAST($t12):INTEGER NOT NULL], expr#15=[ITEM($t11, $t14)], proj#0..10=[{exprs}], message=[$t13], cluster_label=[$t15]) + EnumerableWindow(window#0=[window(rows between UNBOUNDED PRECEDING and 
CURRENT ROW aggs [ROW_NUMBER()])]) + EnumerableWindow(window#0=[window(aggs [cluster_label($11, $12, $13, $14)])], constants=[['login error':VARCHAR, 0.8E0:DOUBLE, 'termlist':VARCHAR, 'non-alphanumeric':VARCHAR]]) + EnumerableCalc(expr#0..16=[{inputs}], proj#0..10=[{exprs}]) + CalciteEnumerableIndexScan(table=[[OpenSearch, opensearch-sql_test_index_account]]) diff --git a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 index 3ada5c96b72..3361cc8de8a 100644 --- a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 @@ -42,6 +42,10 @@ PATTERN: 'PATTERN'; PATTERNS: 'PATTERNS'; NEW_FIELD: 'NEW_FIELD'; KMEANS: 'KMEANS'; +CLUSTER_CMD: 'CLUSTER'; +TERMLIST: 'TERMLIST'; +TERMSET: 'TERMSET'; +NGRAMSET: 'NGRAMSET'; AD: 'AD'; ML: 'ML'; FILLNULL: 'FILLNULL'; @@ -173,6 +177,9 @@ APPEND: 'APPEND'; MULTISEARCH: 'MULTISEARCH'; COUNTFIELD: 'COUNTFIELD'; SHOWCOUNT: 'SHOWCOUNT'; +LABELONLY: 'LABELONLY'; +DELIMS: 'DELIMS'; +T: 'T'; LIMIT: 'LIMIT'; USEOTHER: 'USEOTHER'; OTHERSTR: 'OTHERSTR'; diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index 72d5d0fd76d..9bc7f7e390b 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -71,6 +71,7 @@ commands | patternsCommand | lookupCommand | kmeansCommand + | clusterCommand | adCommand | mlCommand | fillnullCommand @@ -122,6 +123,7 @@ commandName | PATTERNS | LOOKUP | KMEANS + | CLUSTER_CMD | AD | ML | FILLNULL @@ -605,6 +607,26 @@ kmeansParameter | (DISTANCE_TYPE EQUAL distance_type = stringLiteral) ; +clusterCommand + : CLUSTER_CMD source_field = expression (clusterParameter)* + ; + +clusterParameter + : (MATCH EQUAL match = clusterMatchMode) + | (LABELFIELD EQUAL labelfield = qualifiedName) + | (COUNTFIELD EQUAL countfield = qualifiedName) + | (LABELONLY EQUAL labelonly = booleanLiteral) + | (SHOWCOUNT EQUAL showcount = booleanLiteral) + | (DELIMS EQUAL delims = 
stringLiteral) + | (T EQUAL t = decimalLiteral) + ; + +clusterMatchMode + : TERMLIST + | TERMSET + | NGRAMSET + ; + adCommand : AD (adParameter)* ; @@ -1570,10 +1592,9 @@ identifierSeq ; ident - : (DOT)? ID + : (DOT)? (ID | keywordsCanBeId) | BACKTICK ident BACKTICK | BQUOTA_STRING - | keywordsCanBeId ; tableIdent @@ -1742,4 +1763,11 @@ searchableKeyWord | MAX_DEPTH | DEPTH_FIELD | EDGE + // CLUSTER COMMAND KEYWORDS + | T + | DELIMS + | LABELONLY + | TERMLIST + | TERMSET + | NGRAMSET ; diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index 9a92126d2e6..89a7f99f0d8 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -1130,6 +1130,41 @@ public UnresolvedPlan visitKmeansCommand(KmeansCommandContext ctx) { return new Kmeans(builder.build()); } + /** Cluster command. */ + @Override + public UnresolvedPlan visitClusterCommand(OpenSearchPPLParser.ClusterCommandContext ctx) { + UnresolvedExpression sourceField = internalVisitExpression(ctx.source_field); + + double threshold = 0.8; + String matchMode = "termlist"; + String labelField = "cluster_label"; + String countField = "cluster_count"; + boolean labelOnly = false; + boolean showCount = false; + String delims = "non-alphanumeric"; + + for (OpenSearchPPLParser.ClusterParameterContext param : ctx.clusterParameter()) { + if (param.match != null) { + matchMode = param.match.getText().toLowerCase(java.util.Locale.ROOT); + } else if (param.labelfield != null) { + labelField = StringUtils.unquoteText(param.labelfield.getText()); + } else if (param.countfield != null) { + countField = StringUtils.unquoteText(param.countfield.getText()); + } else if (param.labelonly != null) { + labelOnly = Boolean.parseBoolean(param.labelonly.getText()); + } else if (param.showcount != null) { + showCount = Boolean.parseBoolean(param.showcount.getText()); 
+ } else if (param.delims != null) { + delims = StringUtils.unquoteText(param.delims.getText()); + } else if (param.t != null) { + threshold = Double.parseDouble(param.t.getText()); + } + } + + return new org.opensearch.sql.ast.tree.Cluster( + sourceField, threshold, matchMode, labelField, countField, labelOnly, showCount, delims); + } + /** AD command. */ @Override public UnresolvedPlan visitAdCommand(AdCommandContext ctx) { diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java b/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java index 96c0787d5e3..2ed02d9636e 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java @@ -967,6 +967,25 @@ public String visitPatterns(Patterns node, String context) { return builder.toString(); } + @Override + public String visitCluster(org.opensearch.sql.ast.tree.Cluster node, String context) { + String child = node.getChild().get(0).accept(this, context); + String sourceField = visitExpression(node.getSourceField()); + StringBuilder command = new StringBuilder(); + command.append(child).append(" | cluster ").append(sourceField); + + command.append(" t=").append(node.getThreshold()); + + if (!"termlist".equals(node.getMatchMode())) { + command.append(" match=").append(node.getMatchMode()); + } + + command.append(" labelfield=").append(MASK_COLUMN); + command.append(" countfield=").append(MASK_COLUMN); + + return command.toString(); + } + private String groupBy(String groupBy) { return Strings.isNullOrEmpty(groupBy) ? 
"" : StringUtils.format("by %s", groupBy); } diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/TokenizationAnalysisTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/TokenizationAnalysisTest.java new file mode 100644 index 00000000000..a6ec89174f3 --- /dev/null +++ b/ppl/src/test/java/org/opensearch/sql/ppl/TokenizationAnalysisTest.java @@ -0,0 +1,40 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl; + +import static org.junit.Assert.assertTrue; + +import org.antlr.v4.runtime.CommonTokenStream; +import org.antlr.v4.runtime.Token; +import org.junit.Test; +import org.opensearch.sql.common.antlr.CaseInsensitiveCharStream; +import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLLexer; + +public class TokenizationAnalysisTest { + + @Test + public void analyzeTokenization() { + String[] inputs = {"c:t", "c:.t", ".t", "t", "c:test", "c:.test"}; + + for (String input : inputs) { + OpenSearchPPLLexer lexer = new OpenSearchPPLLexer(new CaseInsensitiveCharStream(input)); + CommonTokenStream tokens = new CommonTokenStream(lexer); + tokens.fill(); + + // Verify tokenization succeeds and produces tokens + assertTrue("Should produce at least one token", tokens.getTokens().size() > 0); + + for (Token token : tokens.getTokens()) { + if (token.getType() != Token.EOF) { + String tokenName = OpenSearchPPLLexer.VOCABULARY.getSymbolicName(token.getType()); + // Verify token has valid type and text + assertTrue("Token should have valid type", token.getType() > 0); + assertTrue("Token should have non-null text", token.getText() != null); + } + } + } + } +} diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLClusterTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLClusterTest.java new file mode 100644 index 00000000000..e6cc1e15cda --- /dev/null +++ b/ppl/src/test/java/org/opensearch/sql/ppl/calcite/CalcitePPLClusterTest.java @@ -0,0 +1,258 @@ +/* + * Copyright 
OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl.calcite; + +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.test.CalciteAssert; +import org.junit.Test; + +public class CalcitePPLClusterTest extends CalcitePPLAbstractTest { + + public CalcitePPLClusterTest() { + super(CalciteAssert.SchemaSpec.SCOTT_WITH_TEMPORAL); + } + + @Test + public void testBasicCluster() { + String ppl = "source=EMP | cluster ENAME"; + RelNode root = getRelNode(ppl); + + String expectedLogical = + "LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5]," + + " COMM=[$6], DEPTNO=[$7], cluster_label=[$8])\n" + + " LogicalFilter(condition=[=($9, 1)])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " SAL=[$5], COMM=[$6], DEPTNO=[$7], cluster_label=[$8]," + + " _cluster_convergence_row_num=[ROW_NUMBER() OVER (PARTITION BY $8)])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " SAL=[$5], COMM=[$6], DEPTNO=[$7], cluster_label=[ITEM($8, CAST(ROW_NUMBER() OVER" + + " ()):INTEGER NOT NULL)])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " SAL=[$5], COMM=[$6], DEPTNO=[$7], _cluster_labels_array=[cluster_label($1," + + " 0.8E0:DOUBLE, 'termlist':VARCHAR, 'non-alphanumeric':VARCHAR) OVER ()])\n" + + " LogicalFilter(condition=[IS NOT NULL($1)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + + String expectedSparkSql = + "SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`," + + " `cluster_label`\n" + + "FROM (SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`," + + " `cluster_label`, ROW_NUMBER() OVER (PARTITION BY `cluster_label`)" + + " `_cluster_convergence_row_num`\n" + + "FROM (SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`," + + " 
`_cluster_labels_array`[CAST(ROW_NUMBER() OVER () AS INTEGER)] `cluster_label`\n" + + "FROM (SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`," + + " `cluster_label`(`ENAME`, 8E-1, 'termlist', 'non-alphanumeric') OVER (RANGE BETWEEN" + + " UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) `_cluster_labels_array`\n" + + "FROM `scott`.`EMP`\n" + + "WHERE `ENAME` IS NOT NULL) `t0`) `t1`) `t2`\n" + + "WHERE `_cluster_convergence_row_num` = 1"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } + + @Test + public void testClusterWithThreshold() { + String ppl = "source=EMP | cluster ENAME t=0.8"; + RelNode root = getRelNode(ppl); + + String expectedLogical = + "LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5]," + + " COMM=[$6], DEPTNO=[$7], cluster_label=[$8])\n" + + " LogicalFilter(condition=[=($9, 1)])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " SAL=[$5], COMM=[$6], DEPTNO=[$7], cluster_label=[$8]," + + " _cluster_convergence_row_num=[ROW_NUMBER() OVER (PARTITION BY $8)])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " SAL=[$5], COMM=[$6], DEPTNO=[$7], cluster_label=[ITEM($8, CAST(ROW_NUMBER() OVER" + + " ()):INTEGER NOT NULL)])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " SAL=[$5], COMM=[$6], DEPTNO=[$7], _cluster_labels_array=[cluster_label($1," + + " 0.8E0:DOUBLE, 'termlist':VARCHAR, 'non-alphanumeric':VARCHAR) OVER ()])\n" + + " LogicalFilter(condition=[IS NOT NULL($1)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + } + + @Test + public void testClusterWithTermsetMatch() { + String ppl = "source=EMP | cluster ENAME match=termset"; + RelNode root = getRelNode(ppl); + + String expectedLogical = + "LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5]," + + " COMM=[$6], DEPTNO=[$7], cluster_label=[$8])\n" + 
+ " LogicalFilter(condition=[=($9, 1)])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " SAL=[$5], COMM=[$6], DEPTNO=[$7], cluster_label=[$8]," + + " _cluster_convergence_row_num=[ROW_NUMBER() OVER (PARTITION BY $8)])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " SAL=[$5], COMM=[$6], DEPTNO=[$7], cluster_label=[ITEM($8, CAST(ROW_NUMBER() OVER" + + " ()):INTEGER NOT NULL)])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " SAL=[$5], COMM=[$6], DEPTNO=[$7], _cluster_labels_array=[cluster_label($1," + + " 0.8E0:DOUBLE, 'termset':VARCHAR, 'non-alphanumeric':VARCHAR) OVER ()])\n" + + " LogicalFilter(condition=[IS NOT NULL($1)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + } +
// Verifies that match=ngramset is forwarded verbatim as the 'ngramset' mode argument; plan shape
// (NULL filter -> window UDF -> ITEM/ROW_NUMBER label pick -> convergence dedup) is unchanged.
+ @Test + public void testClusterWithNgramsetMatch() { + String ppl = "source=EMP | cluster ENAME match=ngramset"; + RelNode root = getRelNode(ppl); + + String expectedLogical = + "LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5]," + + " COMM=[$6], DEPTNO=[$7], cluster_label=[$8])\n" + + " LogicalFilter(condition=[=($9, 1)])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " SAL=[$5], COMM=[$6], DEPTNO=[$7], cluster_label=[$8]," + + " _cluster_convergence_row_num=[ROW_NUMBER() OVER (PARTITION BY $8)])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " SAL=[$5], COMM=[$6], DEPTNO=[$7], cluster_label=[ITEM($8, CAST(ROW_NUMBER() OVER" + + " ()):INTEGER NOT NULL)])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " SAL=[$5], COMM=[$6], DEPTNO=[$7], _cluster_labels_array=[cluster_label($1," + + " 0.8E0:DOUBLE, 'ngramset':VARCHAR, 'non-alphanumeric':VARCHAR) OVER ()])\n" + + " LogicalFilter(condition=[IS NOT NULL($1)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; +
verifyLogical(root, expectedLogical); + } +
// Verifies labelfield renames the output label column to my_cluster throughout the plan.
// NOTE(review): countfield=my_count is accepted but no my_count column appears in the expected
// plan — the count projection is only materialized when showcount=true (see the all-parameters
// test in this file).
+ @Test + public void testClusterWithCustomFields() { + String ppl = "source=EMP | cluster ENAME labelfield=my_cluster countfield=my_count"; + RelNode root = getRelNode(ppl); + + String expectedLogical = + "LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5]," + + " COMM=[$6], DEPTNO=[$7], my_cluster=[$8])\n" + + " LogicalFilter(condition=[=($9, 1)])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " SAL=[$5], COMM=[$6], DEPTNO=[$7], my_cluster=[$8]," + + " _cluster_convergence_row_num=[ROW_NUMBER() OVER (PARTITION BY $8)])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " SAL=[$5], COMM=[$6], DEPTNO=[$7], my_cluster=[ITEM($8, CAST(ROW_NUMBER() OVER" + + " ()):INTEGER NOT NULL)])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " SAL=[$5], COMM=[$6], DEPTNO=[$7], _cluster_labels_array=[cluster_label($1," + + " 0.8E0:DOUBLE, 'termlist':VARCHAR, 'non-alphanumeric':VARCHAR) OVER ()])\n" + + " LogicalFilter(condition=[IS NOT NULL($1)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + } +
// Verifies every option combined: t=0.7, termset matching, custom delims ' ', and custom
// label/count field names; showcount=true adds a cluster_size=COUNT() OVER (PARTITION BY label)
// projection, and the convergence filter shifts to column $10.
+ @Test + public void testClusterWithAllParameters() { + String ppl = + "source=EMP | cluster ENAME t=0.7 match=termset labelfield=cluster_id" + + " countfield=cluster_size showcount=true labelonly=false delims=' '"; + RelNode root = getRelNode(ppl); + + String expectedLogical = + "LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5]," + + " COMM=[$6], DEPTNO=[$7], cluster_id=[$8], cluster_size=[$9])\n" + + " LogicalFilter(condition=[=($10, 1)])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " SAL=[$5], COMM=[$6], DEPTNO=[$7], cluster_id=[$8], cluster_size=[$9]," + + " _cluster_convergence_row_num=[ROW_NUMBER() OVER (PARTITION BY $8)])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1],
JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " SAL=[$5], COMM=[$6], DEPTNO=[$7], cluster_id=[$8], cluster_size=[COUNT() OVER" + + " (PARTITION BY $8)])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " SAL=[$5], COMM=[$6], DEPTNO=[$7], cluster_id=[ITEM($8, CAST(ROW_NUMBER() OVER" + + " ()):INTEGER NOT NULL)])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " SAL=[$5], COMM=[$6], DEPTNO=[$7], _cluster_labels_array=[cluster_label($1," + + " 0.7E0:DOUBLE, 'termset':VARCHAR, ' ') OVER ()])\n" + + " LogicalFilter(condition=[IS NOT NULL($1)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + } +
// Verifies clustering on a field other than ENAME: both the cluster_label UDF input and the
// IS NOT NULL pre-filter target `JOB` in the generated Spark SQL; defaults (t=0.8, termlist,
// non-alphanumeric delims) are otherwise identical.
+ @Test + public void testClusterOnDifferentField() { + String ppl = "source=EMP | cluster JOB"; + RelNode root = getRelNode(ppl); + + String expectedSparkSql = + "SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`," + + " `cluster_label`\n" + + "FROM (SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`," + + " `cluster_label`, ROW_NUMBER() OVER (PARTITION BY `cluster_label`)" + + " `_cluster_convergence_row_num`\n" + + "FROM (SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`," + + " `_cluster_labels_array`[CAST(ROW_NUMBER() OVER () AS INTEGER)] `cluster_label`\n" + + "FROM (SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`," + + " `cluster_label`(`JOB`, 8E-1, 'termlist', 'non-alphanumeric') OVER (RANGE BETWEEN" + + " UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) `_cluster_labels_array`\n" + + "FROM `scott`.`EMP`\n" + + "WHERE `JOB` IS NOT NULL) `t0`) `t1`) `t2`\n" + + "WHERE `_cluster_convergence_row_num` = 1"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } +
// Verifies labelonly=true drops the convergence ROW_NUMBER()/filter stage entirely: every row
// keeps its label and no per-cluster deduplication subquery is emitted.
+ @Test + public void testClusterLabelOnly() { + String ppl = "source=EMP | cluster ENAME labelonly=true"; + RelNode root = getRelNode(ppl); + + String expectedLogical = + "LogicalProject(EMPNO=[$0],
ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5]," + + " COMM=[$6], DEPTNO=[$7], cluster_label=[ITEM($8, CAST(ROW_NUMBER() OVER" + + " ()):INTEGER NOT NULL)])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " SAL=[$5], COMM=[$6], DEPTNO=[$7], _cluster_labels_array=[cluster_label($1," + + " 0.8E0:DOUBLE, 'termlist':VARCHAR, 'non-alphanumeric':VARCHAR) OVER ()])\n" + + " LogicalFilter(condition=[IS NOT NULL($1)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + + String expectedSparkSql = + "SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`," + + " `_cluster_labels_array`[CAST(ROW_NUMBER() OVER () AS INTEGER)] `cluster_label`\n" + + "FROM (SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`," + + " `cluster_label`(`ENAME`, 8E-1, 'termlist', 'non-alphanumeric') OVER (RANGE BETWEEN" + + " UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) `_cluster_labels_array`\n" + + "FROM `scott`.`EMP`\n" + + "WHERE `ENAME` IS NOT NULL) `t0`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } +
// Verifies labelonly=true combined with showcount=true: no dedup filter, but a
// cluster_count=COUNT() OVER (PARTITION BY label) projection is layered on top of the labels.
+ @Test + public void testClusterLabelOnlyWithShowCount() { + String ppl = "source=EMP | cluster ENAME labelonly=true showcount=true"; + RelNode root = getRelNode(ppl); + + String expectedLogical = + "LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5]," + + " COMM=[$6], DEPTNO=[$7], cluster_label=[$8], cluster_count=[COUNT() OVER" + + " (PARTITION BY $8)])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " SAL=[$5], COMM=[$6], DEPTNO=[$7], cluster_label=[ITEM($8, CAST(ROW_NUMBER() OVER" + + " ()):INTEGER NOT NULL)])\n" + + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4]," + + " SAL=[$5], COMM=[$6], DEPTNO=[$7], _cluster_labels_array=[cluster_label($1," + + " 0.8E0:DOUBLE, 'termlist':VARCHAR, 'non-alphanumeric':VARCHAR) OVER ()])\n" + + "
LogicalFilter(condition=[IS NOT NULL($1)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; + verifyLogical(root, expectedLogical); + + String expectedSparkSql = + "SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`," + + " `cluster_label`, COUNT(*) OVER (PARTITION BY `cluster_label` RANGE BETWEEN" + + " UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) `cluster_count`\n" + + "FROM (SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`," + + " `_cluster_labels_array`[CAST(ROW_NUMBER() OVER () AS INTEGER)] `cluster_label`\n" + + "FROM (SELECT `EMPNO`, `ENAME`, `JOB`, `MGR`, `HIREDATE`, `SAL`, `COMM`, `DEPTNO`," + + " `cluster_label`(`ENAME`, 8E-1, 'termlist', 'non-alphanumeric') OVER (RANGE BETWEEN" + + " UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) `_cluster_labels_array`\n" + + "FROM `scott`.`EMP`\n" + + "WHERE `ENAME` IS NOT NULL) `t0`) `t1`"; + verifyPPLToSparkSQL(root, expectedSparkSql); + } +} diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java index bb720bd4207..f18ae58163e 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java @@ -1013,6 +1013,26 @@ public void testMultisearch() { + " [search source=accounts | where age = 25]")); } +
// Verifies anonymization of the cluster command: table and field names are masked to
// 'table'/'identifier', custom labelfield/countfield values are masked to 'identifier', and the
// default threshold is rendered explicitly as t=0.8; explicit t and match options are preserved.
+ @Test + public void testClusterCommand() { + assertEquals( + "source=table | cluster identifier t=0.8 labelfield=identifier countfield=identifier", + anonymize("source=t | cluster message")); + assertEquals( + "source=table | cluster identifier t=0.8 labelfield=identifier countfield=identifier", + anonymize("source=t | cluster message t=0.8")); + assertEquals( + "source=table | cluster identifier t=0.8 match=termset labelfield=identifier" + + " countfield=identifier", + anonymize("source=t | cluster message match=termset")); + assertEquals( +
// Final case: explicit t=0.7 and match=ngramset survive anonymization while the user-chosen
// labelfield/countfield names (cluster_label/cluster_count) are masked to 'identifier'.
"source=table | cluster identifier t=0.7 match=ngramset labelfield=identifier" + + " countfield=identifier", + anonymize( + "source=t | cluster message t=0.7 match=ngramset labelfield=cluster_label" + + " countfield=cluster_count")); + } + private String anonymize(String query) { AstBuilder astBuilder = new AstBuilder(query, settings); return anonymize(astBuilder.visit(parser.parse(query)));