Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion python/sedonadb/tests/functions/test_aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
# specific language governing permissions and limitations
# under the License.

import geopandas
import geopandas.testing
import pytest
import shapely
from sedonadb.testing import PostGIS, SedonaDB
from sedonadb.testing import PostGIS, SedonaDB, skip_if_not_exists


# Aggregate functions don't have a suffix in PostGIS
Expand Down Expand Up @@ -136,6 +138,36 @@ def test_st_envelope_agg_many_groups(eng, con):
eng.assert_result(result, expected)


# Check the grouped envelope aggregate against per-group bounds computed
# independently with geopandas, using a real-world point dataset read from
# Parquet. Runs once per parametrized engine (SedonaDB and PostGIS).
@pytest.mark.parametrize("eng", [SedonaDB, PostGIS])
def test_st_envelope_nontrivial_input(eng, geoarrow_data):
path = geoarrow_data / "ns-water" / "files" / "ns-water_water-point_geo.parquet"
# NOTE(review): presumably create_or_skip() skips the test when the engine is
# unavailable and skip_if_not_exists() skips when the data file has not been
# fetched -- confirm against sedonadb.testing.
eng = eng.create_or_skip()
skip_if_not_exists(path)

# Reference result, computed with geopandas only (no engine involved).
df_points_geopandas = geopandas.read_parquet(path)
expected = (
df_points_geopandas.groupby(df_points_geopandas.FEAT_CODE)["geometry"]
# A group with exactly one row is expected to produce a Point built from the
# first two total_bounds values (minx, miny); any other group is expected to
# produce the rectangle covering the group's total bounds.
.apply(
lambda group: shapely.Point(*group.total_bounds[:2])
if len(group) == 1
else shapely.box(*group.total_bounds)
)
.reset_index()
).set_crs(df_points_geopandas.crs)

# Same computation through the engine under test: one envelope per FEAT_CODE,
# ordered by FEAT_CODE before comparison with the geopandas result.
eng.create_table_parquet("pts", path)
result = eng.execute_and_collect(f"""
SELECT "FEAT_CODE", {call_st_envelope_agg(eng, "geometry")} AS geometry
FROM pts
GROUP BY "FEAT_CODE"
ORDER BY "FEAT_CODE"
""")

# This CRS is too complicated to check roundtripping through PostGIS,
# so the comparison below ignores the CRS (check_crs=False).
df = eng.result_to_pandas(result)
geopandas.testing.assert_geodataframe_equal(df, expected, check_crs=False)


@pytest.mark.parametrize("eng", [SedonaDB, PostGIS])
def test_st_collect_points(eng):
eng = eng.create_or_skip()
Expand Down
52 changes: 37 additions & 15 deletions rust/sedona-functions/src/st_envelope_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ impl BoundsGroupsAccumulator2D {
group_indices: &[usize],
opt_filter: Option<&BooleanArray>,
total_num_groups: usize,
input_type: SedonaType,
) -> Result<()> {
// Check some of our assumptions about how this will be called
debug_assert_eq!(self.offset, 0);
Expand All @@ -211,7 +212,7 @@ impl BoundsGroupsAccumulator2D {
debug_assert_eq!(values[0].len(), filter.len());
}

let arg_types = [self.input_type.clone()];
let arg_types = [input_type.clone()];
let args = [ColumnarValue::Array(values[0].clone())];
let executor = WkbExecutor::new(&arg_types, &args);
self.xs.resize(total_num_groups, Interval::empty());
Expand Down Expand Up @@ -300,7 +301,13 @@ impl GroupsAccumulator for BoundsGroupsAccumulator2D {
opt_filter: Option<&BooleanArray>,
total_num_groups: usize,
) -> Result<()> {
self.execute_update(values, group_indices, opt_filter, total_num_groups)
self.execute_update(
values,
group_indices,
opt_filter,
total_num_groups,
self.input_type.clone(),
)
}

fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
Expand All @@ -314,8 +321,15 @@ impl GroupsAccumulator for BoundsGroupsAccumulator2D {
opt_filter: Option<&arrow_array::BooleanArray>,
total_num_groups: usize,
) -> Result<()> {
// In this case, our state is identical to our input values
self.execute_update(values, group_indices, opt_filter, total_num_groups)
// In this case, our state is identical to our input values, except that the state's
// geometry representation is always WKB_GEOMETRY (regardless of the input type).
self.execute_update(
values,
group_indices,
opt_filter,
total_num_groups,
WKB_GEOMETRY,
)
}

fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
Expand All @@ -333,7 +347,9 @@ impl GroupsAccumulator for BoundsGroupsAccumulator2D {
mod test {
use datafusion_expr::AggregateUDF;
use rstest::rstest;
use sedona_schema::datatypes::{WKB_GEOMETRY_ITEM_CRS, WKB_VIEW_GEOMETRY};
use sedona_schema::datatypes::{
WKB_GEOMETRY_ITEM_CRS, WKB_VIEW_GEOMETRY, WKB_VIEW_GEOMETRY_ITEM_CRS,
};
use sedona_testing::{
compare::{assert_array_equal, assert_scalar_equal, assert_scalar_equal_wkb_geometry},
create::{create_array, create_scalar},
Expand Down Expand Up @@ -400,36 +416,42 @@ mod test {
);
}

#[test]
fn udf_invoke_item_crs() {
let sedona_type = WKB_GEOMETRY_ITEM_CRS.clone();
#[rstest]
fn udf_invoke_item_crs(
#[values(WKB_GEOMETRY_ITEM_CRS.clone(), WKB_VIEW_GEOMETRY_ITEM_CRS.clone())]
sedona_type: SedonaType,
) {
let tester =
AggregateUdfTester::new(st_envelope_agg_udf().into(), vec![sedona_type.clone()]);
assert_eq!(tester.return_type().unwrap(), sedona_type.clone());
assert_eq!(tester.return_type().unwrap(), WKB_GEOMETRY_ITEM_CRS.clone());

let batches = vec![
vec![Some("POINT (0 1)"), None, Some("POINT (2 3)")],
vec![Some("POINT (4 5)"), None, Some("POINT (6 7)")],
];
let expected = create_scalar(Some("POLYGON((0 1, 0 7, 6 7, 6 1, 0 1))"), &sedona_type);
let expected = create_scalar(
Some("POLYGON((0 1, 0 7, 6 7, 6 1, 0 1))"),
&WKB_GEOMETRY_ITEM_CRS,
);

assert_scalar_equal(&tester.aggregate_wkt(batches).unwrap(), &expected);
}

#[test]
fn udf_grouped_accumulate() {
let tester = AggregateUdfTester::new(st_envelope_agg_udf().into(), vec![WKB_GEOMETRY]);
#[rstest]
fn udf_grouped_accumulate(#[values(WKB_GEOMETRY, WKB_VIEW_GEOMETRY)] sedona_type: SedonaType) {
let tester =
AggregateUdfTester::new(st_envelope_agg_udf().into(), vec![sedona_type.clone()]);
assert_eq!(tester.return_type().unwrap(), WKB_GEOMETRY);

// Six elements, four groups, with one all null group and one partially null group
let group_indices = vec![0, 3, 1, 1, 0, 2];
let array0 = create_array(
&[Some("POINT (0 1)"), None, Some("POINT (2 3)")],
&WKB_GEOMETRY,
&sedona_type,
);
let array1 = create_array(
&[Some("POINT (4 5)"), None, Some("POINT (6 7)")],
&WKB_GEOMETRY,
&sedona_type,
);
let batches = vec![array0, array1];

Expand Down
8 changes: 8 additions & 0 deletions rust/sedona-schema/src/datatypes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,18 @@ pub const RASTER: SedonaType = SedonaType::Raster;
pub static WKB_GEOMETRY_ITEM_CRS: LazyLock<SedonaType> = LazyLock::new(|| {
    // Built once, on first access. A failure here would mean the constant base type was
    // rejected, which is a programming error, so panic with a message naming the invariant.
    SedonaType::new_item_crs(&WKB_GEOMETRY)
        .expect("WKB_GEOMETRY must be accepted by SedonaType::new_item_crs")
});

/// Sentinel for [SedonaType::new_item_crs] containing [WKB_VIEW_GEOMETRY]
///
/// The value is created lazily, the first time the static is dereferenced.
///
/// # Panics
///
/// The first access panics if [SedonaType::new_item_crs] does not accept
/// [WKB_VIEW_GEOMETRY].
pub static WKB_VIEW_GEOMETRY_ITEM_CRS: LazyLock<SedonaType> = LazyLock::new(|| {
    // A failure here is a programming error, so panic with a message naming the invariant.
    SedonaType::new_item_crs(&WKB_VIEW_GEOMETRY)
        .expect("WKB_VIEW_GEOMETRY must be accepted by SedonaType::new_item_crs")
});

/// Sentinel for [SedonaType::new_item_crs] containing [WKB_GEOGRAPHY]
///
/// The value is created lazily, the first time the static is dereferenced.
///
/// # Panics
///
/// The first access panics if [SedonaType::new_item_crs] does not accept
/// [WKB_GEOGRAPHY].
pub static WKB_GEOGRAPHY_ITEM_CRS: LazyLock<SedonaType> = LazyLock::new(|| {
    // A failure here is a programming error, so panic with a message naming the invariant.
    SedonaType::new_item_crs(&WKB_GEOGRAPHY)
        .expect("WKB_GEOGRAPHY must be accepted by SedonaType::new_item_crs")
});

/// Sentinel for [SedonaType::new_item_crs] containing [WKB_VIEW_GEOGRAPHY]
///
/// The value is created lazily, the first time the static is dereferenced.
///
/// # Panics
///
/// The first access panics if [SedonaType::new_item_crs] does not accept
/// [WKB_VIEW_GEOGRAPHY].
pub static WKB_VIEW_GEOGRAPHY_ITEM_CRS: LazyLock<SedonaType> = LazyLock::new(|| {
    // A failure here is a programming error, so panic with a message naming the invariant.
    SedonaType::new_item_crs(&WKB_VIEW_GEOGRAPHY)
        .expect("WKB_VIEW_GEOGRAPHY must be accepted by SedonaType::new_item_crs")
});

/// Create a static value for the [`SedonaType::Raster`] that's initialized exactly once,
/// on first access
static RASTER_DATATYPE: LazyLock<DataType> =
Expand Down