Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 104 additions & 6 deletions xrspatial/tests/test_zonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,23 @@ def data_values_3d(backend):

@pytest.fixture
def result_default_stats():
    """Expected per-zone results when `stats` is run with the default stat funcs."""
    return {
        'zone': [0, 1, 2, 3],
        'mean': [0, 1, 2, 2.4],
        'max': [0, 1, 2, 3],
        'min': [0, 1, 2, 0],
        'sum': [0, 6, 8, 12],
        'std': [0, 0, 0, 1.2],
        'var': [0, 0, 0, 1.44],
        'count': [5, 6, 4, 5],
        'majority': [0, 1, 2, 3],
    }


@pytest.fixture
def result_default_stats_no_majority():
"""Expected result for dask backend which doesn't support majority."""
expected_result = {
'zone': [0, 1, 2, 3],
'mean': [0, 1, 2, 2.4],
Expand Down Expand Up @@ -102,13 +119,35 @@ def result_default_stats_dataarray():

[[5., 5., 6., 6., 4., 4., 5., 5.],
[5., 5., 6., 6., 4., 4., 5., 5.],
[5., 5., 6., 6., 4., np.nan, 5., 5.]]]
[5., 5., 6., 6., 4., np.nan, 5., 5.]],

[[0., 0., 1., 1., 2., 2., 3., 3.],
[0., 0., 1., 1., 2., 2., 3., 3.],
[0., 0., 1., 1., 2., np.nan, 3., 3.]]]
)
return expected_result


@pytest.fixture
def result_zone_ids_stats():
    """Zone ids to select, plus the expected default-stats results for them."""
    selected_zones = [0, 3]
    expected = {
        'zone': [0, 3],
        'mean': [0, 2.4],
        'max': [0, 3],
        'min': [0, 0],
        'sum': [0, 12],
        'std': [0, 1.2],
        'var': [0, 1.44],
        'count': [5, 5],
        'majority': [0, 3],
    }
    return selected_zones, expected


@pytest.fixture
def result_zone_ids_stats_no_majority():
"""Expected result for dask backend which doesn't support majority."""
zone_ids = [0, 3]
expected_result = {
'zone': [0, 3],
Expand Down Expand Up @@ -153,7 +192,11 @@ def result_zone_ids_stats_dataarray():

[[5., 5., np.nan, np.nan, np.nan, np.nan, 5., 5.],
[5., 5., np.nan, np.nan, np.nan, np.nan, 5., 5.],
[5., 5., np.nan, np.nan, np.nan, np.nan, 5., 5.]]])
[5., 5., np.nan, np.nan, np.nan, np.nan, 5., 5.]],

[[0., 0., np.nan, np.nan, np.nan, np.nan, 3., 3.],
[0., 0., np.nan, np.nan, np.nan, np.nan, 3., 3.],
[0., 0., np.nan, np.nan, np.nan, np.nan, 3., 3.]]])

return zone_ids, expected_result

Expand Down Expand Up @@ -362,7 +405,8 @@ def check_results(


@pytest.mark.parametrize("backend", ['numpy', 'dask+numpy', 'cupy'])
def test_default_stats(backend, data_zones, data_values_2d, result_default_stats):
def test_default_stats(backend, data_zones, data_values_2d, result_default_stats,
result_default_stats_no_majority):
if backend == 'cupy' and not has_cuda_and_cupy():
pytest.skip("Requires CUDA and CuPy")

Expand All @@ -374,7 +418,9 @@ def test_default_stats(backend, data_zones, data_values_2d, result_default_stats
copied_data_values_2d = copy.deepcopy(data_values_2d)

df_result = stats(zones=data_zones, values=data_values_2d)
check_results(backend, df_result, result_default_stats)
# dask doesn't support majority stat (can't be computed block-by-block)
expected_result = result_default_stats_no_majority if 'dask' in backend else result_default_stats
check_results(backend, df_result, expected_result)

assert_input_data_unmodified(data_zones, copied_data_zones)
assert_input_data_unmodified(data_values_2d, copied_data_values_2d)
Expand Down Expand Up @@ -403,7 +449,8 @@ def test_default_stats_dataarray(


@pytest.mark.parametrize("backend", ['numpy', 'dask+numpy', 'cupy'])
def test_zone_ids_stats(backend, data_zones, data_values_2d, result_zone_ids_stats):
def test_zone_ids_stats(backend, data_zones, data_values_2d, result_zone_ids_stats,
result_zone_ids_stats_no_majority):
if backend == 'cupy' and not has_cuda_and_cupy():
pytest.skip("Requires CUDA and CuPy")

Expand All @@ -414,7 +461,11 @@ def test_zone_ids_stats(backend, data_zones, data_values_2d, result_zone_ids_sta
copied_data_zones = copy.deepcopy(data_zones)
copied_data_values_2d = copy.deepcopy(data_values_2d)

zone_ids, expected_result = result_zone_ids_stats
# dask doesn't support majority stat (can't be computed block-by-block)
if 'dask' in backend:
zone_ids, expected_result = result_zone_ids_stats_no_majority
else:
zone_ids, expected_result = result_zone_ids_stats
df_result = stats(zones=data_zones, values=data_values_2d,
zone_ids=zone_ids)
check_results(backend, df_result, expected_result)
Expand Down Expand Up @@ -491,6 +542,53 @@ def test_custom_stats_dataarray(backend, data_zones, data_values_2d, result_cust
assert_input_data_unmodified(data_values_2d, copied_data_values_2d)


@pytest.mark.parametrize("backend", ['numpy', 'cupy'])
def test_majority_stats(backend, data_zones, data_values_2d):
    """Test that majority stat returns the most frequent value in each zone."""
    if backend == 'cupy' and not has_cuda_and_cupy():
        pytest.skip("Requires CUDA and CuPy")

    # snapshot the inputs so we can verify the function does not mutate them
    zones_before = copy.deepcopy(data_zones)
    values_before = copy.deepcopy(data_values_2d)

    expected_result = {
        'zone': [0, 1, 2, 3],
        'majority': [0, 1, 2, 3],
    }
    df_result = stats(zones=data_zones, values=data_values_2d, stats_funcs=['majority'])
    check_results(backend, df_result, expected_result)

    assert_input_data_unmodified(data_zones, zones_before)
    assert_input_data_unmodified(data_values_2d, values_before)


@pytest.mark.parametrize("backend", ['numpy', 'cupy'])
def test_majority_with_ties(backend):
    """Test majority when there are ties - should return the smallest value."""
    if backend == 'cupy' and not has_cuda_and_cupy():
        pytest.skip("Requires CUDA and CuPy")

    # Hand-built rasters engineered to produce tied counts.
    # Zone 1 covers the whole first row and the first half of the second row;
    # zone 2 covers the rest.
    raw_zones = np.array([[1, 1, 1, 1],
                          [1, 1, 2, 2],
                          [2, 2, 2, 2]])
    # Zone 1 sees values [1, 1, 2, 2, 3, 3]: a three-way tie at count 2.
    # Zone 2 sees values [5, 5, 5, 5, 6, 6]: 5 wins outright with count 4.
    raw_values = np.array([[1, 1, 2, 2],  # zone 1 has two 1s and two 2s - tie
                           [3, 3, 5, 5],  # zone 1 also has two 3s, zone 2 has two 5s
                           [5, 5, 6, 6]])  # zone 2 has two more 5s and two 6s

    zones = create_test_raster(raw_zones, backend)
    values = create_test_raster(raw_values, backend)

    df_result = stats(zones=zones, values=values, stats_funcs=['majority'])
    # Zone 1: values [1, 1, 2, 2, 3, 3] - three values with count 2, majority is 1 (smallest)
    # Zone 2: values [5, 5, 5, 5, 6, 6] - majority is 5 (count 4)
    check_results(backend, df_result, {
        'zone': [1, 2],
        'majority': [1, 5],
    })


def test_zonal_stats_against_qgis(elevation_raster_no_nans, raster, qgis_zonal_stats):
stats_funcs = list(set(qgis_zonal_stats.keys()) - set(['zone']))
zones_agg = create_test_raster(raster)
Expand Down
24 changes: 21 additions & 3 deletions xrspatial/zonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,21 @@ def _stats_count(data):
return stats_count


def _stats_majority(data):
if isinstance(data, np.ndarray):
# numpy case
values, counts = np.unique(data, return_counts=True)
return values[np.argmax(counts)]
elif isinstance(data, cupy.ndarray):
# cupy case
values, counts = cupy.unique(data, return_counts=True)
return values[cupy.argmax(counts)]
else:
# dask case
values, counts = da.unique(data, return_counts=True)
return values[da.argmax(counts)]


_DEFAULT_STATS = dict(
mean=lambda z: z.mean(),
max=lambda z: z.max(),
Expand All @@ -61,6 +76,7 @@ def _stats_count(data):
std=lambda z: z.std(),
var=lambda z: z.var(),
count=lambda z: _stats_count(z),
majority=lambda z: _stats_majority(z),
)


Expand Down Expand Up @@ -246,8 +262,9 @@ def _stats_dask_numpy(
stats_df = dd.concat([dd.from_dask_array(s) for s in stats_dict.values()], axis=1)
# name columns
stats_df.columns = stats_dict.keys()
# select columns
stats_df = stats_df[['zone'] + list(stats_funcs.keys())]
# select columns (only include stats that were actually computed)
computed_stats = [s for s in stats_funcs.keys() if s in stats_dict]
stats_df = stats_df[['zone'] + computed_stats]

if not select_all_zones:
# only return zones specified in `zone_ids`
Expand Down Expand Up @@ -414,6 +431,7 @@ def stats(
"std",
"var",
"count",
"majority",
],
nodata_values: Union[int, float] = None,
return_type: str = 'pandas.DataFrame',
Expand Down Expand Up @@ -449,7 +467,7 @@ def stats(
all zones will be used.

stats_funcs : dict, or list of strings, default=['mean', 'max', 'min',
'sum', 'std', 'var', 'count']
'sum', 'std', 'var', 'count', 'majority']
The statistics to calculate for each zone. If a list, possible
choices are subsets of the default options.
In the dictionary case, all of its values must be
Expand Down