Skip to content

Commit dc229e6

Browse files
rhshadrachasv-bot
andauthored
TST/CLN: Consolidate creation of groupby method args (#47973)
* TST/CLN: Consolidate creation of groupby method args * Rename to get_groupby_method_args Co-authored-by: asv-bot <[email protected]>
1 parent 5802686 commit dc229e6

File tree

7 files changed

+52
-63
lines changed

7 files changed

+52
-63
lines changed

pandas/tests/groupby/__init__.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
def get_groupby_method_args(name, obj):
2+
"""
3+
Get required arguments for a groupby method.
4+
5+
When parametrizing a test over groupby methods (e.g. "sum", "mean", "fillna"),
6+
it is often the case that arguments are required for certain methods.
7+
8+
Parameters
9+
----------
10+
name: str
11+
Name of the method.
12+
obj: Series or DataFrame
13+
pandas object that is being grouped.
14+
15+
Returns
16+
-------
17+
A tuple of required arguments for the method.
18+
"""
19+
if name in ("nth", "fillna", "take"):
20+
return (0,)
21+
if name == "quantile":
22+
return (0.5,)
23+
if name == "corrwith":
24+
return (obj,)
25+
if name == "tshift":
26+
return (0, 0)
27+
return ()

pandas/tests/groupby/test_apply.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
)
1818
import pandas._testing as tm
1919
from pandas.core.api import Int64Index
20+
from pandas.tests.groupby import get_groupby_method_args
2021

2122

2223
def test_apply_issues():
@@ -1069,7 +1070,7 @@ def test_apply_is_unchanged_when_other_methods_are_called_first(reduction_func):
10691070

10701071
# Check output when another method is called before .apply()
10711072
grp = df.groupby(by="a")
1072-
args = {"nth": [0], "corrwith": [df]}.get(reduction_func, [])
1073+
args = get_groupby_method_args(reduction_func, df)
10731074
with tm.assert_produces_warning(warn, match="The 'mad' method is deprecated"):
10741075
_ = getattr(grp, reduction_func)(*args)
10751076
result = grp.apply(sum)

pandas/tests/groupby/test_categorical.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
qcut,
1515
)
1616
import pandas._testing as tm
17+
from pandas.tests.groupby import get_groupby_method_args
1718

1819

1920
def cartesian_product_for_groupers(result, args, names, fill_value=np.NaN):
@@ -1373,7 +1374,7 @@ def test_series_groupby_on_2_categoricals_unobserved(reduction_func, observed, r
13731374
"value": [0.1] * 4,
13741375
}
13751376
)
1376-
args = {"nth": [0]}.get(reduction_func, [])
1377+
args = get_groupby_method_args(reduction_func, df)
13771378

13781379
expected_length = 4 if observed else 16
13791380

@@ -1409,7 +1410,7 @@ def test_series_groupby_on_2_categoricals_unobserved_zeroes_or_nans(
14091410
}
14101411
)
14111412
unobserved = [tuple("AC"), tuple("BC"), tuple("CA"), tuple("CB"), tuple("CC")]
1412-
args = {"nth": [0]}.get(reduction_func, [])
1413+
args = get_groupby_method_args(reduction_func, df)
14131414

14141415
series_groupby = df.groupby(["cat_1", "cat_2"], observed=False)["value"]
14151416
agg = getattr(series_groupby, reduction_func)
@@ -1450,7 +1451,7 @@ def test_dataframe_groupby_on_2_categoricals_when_observed_is_true(reduction_fun
14501451

14511452
df_grp = df.groupby(["cat_1", "cat_2"], observed=True)
14521453

1453-
args = {"nth": [0], "corrwith": [df]}.get(reduction_func, [])
1454+
args = get_groupby_method_args(reduction_func, df)
14541455
with tm.assert_produces_warning(warn, match="The 'mad' method is deprecated"):
14551456
res = getattr(df_grp, reduction_func)(*args)
14561457

@@ -1482,7 +1483,7 @@ def test_dataframe_groupby_on_2_categoricals_when_observed_is_false(
14821483

14831484
df_grp = df.groupby(["cat_1", "cat_2"], observed=observed)
14841485

1485-
args = {"nth": [0], "corrwith": [df]}.get(reduction_func, [])
1486+
args = get_groupby_method_args(reduction_func, df)
14861487
with tm.assert_produces_warning(warn, match="The 'mad' method is deprecated"):
14871488
res = getattr(df_grp, reduction_func)(*args)
14881489

pandas/tests/groupby/test_function.py

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
)
1919
import pandas._testing as tm
2020
import pandas.core.nanops as nanops
21+
from pandas.tests.groupby import get_groupby_method_args
2122
from pandas.util import _test_decorators as td
2223

2324

@@ -570,7 +571,7 @@ def test_axis1_numeric_only(request, groupby_func, numeric_only):
570571
groups = [1, 2, 3, 1, 2, 3, 1, 2, 3, 4]
571572
gb = df.groupby(groups)
572573
method = getattr(gb, groupby_func)
573-
args = (0,) if groupby_func == "fillna" else ()
574+
args = get_groupby_method_args(groupby_func, df)
574575
kwargs = {"axis": 1}
575576
if numeric_only is not None:
576577
# when numeric_only is None we don't pass any argument
@@ -1366,12 +1367,7 @@ def test_deprecate_numeric_only(
13661367
# has_arg: Whether the op has a numeric_only arg
13671368
df = DataFrame({"a1": [1, 1], "a2": [2, 2], "a3": [5, 6], "b": 2 * [object]})
13681369

1369-
if kernel == "corrwith":
1370-
args = (df,)
1371-
elif kernel == "nth" or kernel == "fillna":
1372-
args = (0,)
1373-
else:
1374-
args = ()
1370+
args = get_groupby_method_args(kernel, df)
13751371
kwargs = {} if numeric_only is lib.no_default else {"numeric_only": numeric_only}
13761372

13771373
gb = df.groupby(keys)
@@ -1451,22 +1447,7 @@ def test_deprecate_numeric_only_series(dtype, groupby_func, request):
14511447
expected_gb = expected_ser.groupby(grouper)
14521448
expected_method = getattr(expected_gb, groupby_func)
14531449

1454-
if groupby_func == "corrwith":
1455-
args = (ser,)
1456-
elif groupby_func == "corr":
1457-
args = (ser,)
1458-
elif groupby_func == "cov":
1459-
args = (ser,)
1460-
elif groupby_func == "nth":
1461-
args = (0,)
1462-
elif groupby_func == "fillna":
1463-
args = (True,)
1464-
elif groupby_func == "take":
1465-
args = ([0],)
1466-
elif groupby_func == "quantile":
1467-
args = (0.5,)
1468-
else:
1469-
args = ()
1450+
args = get_groupby_method_args(groupby_func, ser)
14701451

14711452
fails_on_numeric_object = (
14721453
"corr",

pandas/tests/groupby/test_groupby.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from pandas.core.arrays import BooleanArray
3030
import pandas.core.common as com
3131
from pandas.core.groupby.base import maybe_normalize_deprecated_kernels
32+
from pandas.tests.groupby import get_groupby_method_args
3233

3334

3435
def test_repr():
@@ -2366,14 +2367,10 @@ def test_dup_labels_output_shape(groupby_func, idx):
23662367
df = DataFrame([[1, 1]], columns=idx)
23672368
grp_by = df.groupby([0])
23682369

2369-
args = []
2370-
if groupby_func in {"fillna", "nth"}:
2371-
args.append(0)
2372-
elif groupby_func == "corrwith":
2373-
args.append(df)
2374-
elif groupby_func == "tshift":
2370+
if groupby_func == "tshift":
23752371
df.index = [Timestamp("today")]
2376-
args.extend([1, "D"])
2372+
# args.extend([1, "D"])
2373+
args = get_groupby_method_args(groupby_func, df)
23772374

23782375
with tm.assert_produces_warning(warn, match="is deprecated"):
23792376
result = getattr(grp_by, groupby_func)(*args)

pandas/tests/groupby/test_groupby_subclass.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
)
1111
import pandas._testing as tm
1212
from pandas.core.groupby.base import maybe_normalize_deprecated_kernels
13+
from pandas.tests.groupby import get_groupby_method_args
1314

1415

1516
@pytest.mark.parametrize(
@@ -34,13 +35,7 @@ def test_groupby_preserves_subclass(obj, groupby_func):
3435
# Groups should preserve subclass type
3536
assert isinstance(grouped.get_group(0), type(obj))
3637

37-
args = []
38-
if groupby_func in {"fillna", "nth"}:
39-
args.append(0)
40-
elif groupby_func == "corrwith":
41-
args.append(obj)
42-
elif groupby_func == "tshift":
43-
args.extend([0, 0])
38+
args = get_groupby_method_args(groupby_func, obj)
4439

4540
with tm.assert_produces_warning(warn, match="is deprecated"):
4641
result1 = getattr(grouped, groupby_func)(*args)

pandas/tests/groupby/transform/test_transform.py

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import pandas._testing as tm
2323
from pandas.core.groupby.base import maybe_normalize_deprecated_kernels
2424
from pandas.core.groupby.generic import DataFrameGroupBy
25+
from pandas.tests.groupby import get_groupby_method_args
2526

2627

2728
def assert_fp_equal(a, b):
@@ -172,14 +173,10 @@ def test_transform_axis_1(request, transformation_func):
172173
msg = "ngroup fails with axis=1: #45986"
173174
request.node.add_marker(pytest.mark.xfail(reason=msg))
174175

175-
warn = None
176-
if transformation_func == "tshift":
177-
warn = FutureWarning
178-
179-
request.node.add_marker(pytest.mark.xfail(reason="tshift is deprecated"))
180-
args = ("ffill",) if transformation_func == "fillna" else ()
176+
warn = FutureWarning if transformation_func == "tshift" else None
181177

182178
df = DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]}, index=["x", "y"])
179+
args = get_groupby_method_args(transformation_func, df)
183180
with tm.assert_produces_warning(warn):
184181
result = df.groupby([0, 0, 1], axis=1).transform(transformation_func, *args)
185182
expected = df.T.groupby([0, 0, 1]).transform(transformation_func, *args).T
@@ -1168,7 +1165,7 @@ def test_transform_agg_by_name(request, reduction_func, obj):
11681165
pytest.mark.xfail(reason="TODO: implement SeriesGroupBy.corrwith")
11691166
)
11701167

1171-
args = {"nth": [0], "quantile": [0.5], "corrwith": [obj]}.get(func, [])
1168+
args = get_groupby_method_args(reduction_func, obj)
11721169
with tm.assert_produces_warning(warn, match="The 'mad' method is deprecated"):
11731170
result = g.transform(func, *args)
11741171

@@ -1370,12 +1367,7 @@ def test_null_group_str_reducer(request, dropna, reduction_func):
13701367
df = DataFrame({"A": [1, 1, np.nan, np.nan], "B": [1, 2, 2, 3]}, index=index)
13711368
gb = df.groupby("A", dropna=dropna)
13721369

1373-
if reduction_func == "corrwith":
1374-
args = (df["B"],)
1375-
elif reduction_func == "nth":
1376-
args = (0,)
1377-
else:
1378-
args = ()
1370+
args = get_groupby_method_args(reduction_func, df)
13791371

13801372
# Manually handle reducers that don't fit the generic pattern
13811373
# Set expected with dropna=False, then replace if necessary
@@ -1418,8 +1410,8 @@ def test_null_group_str_transformer(request, dropna, transformation_func):
14181410
if transformation_func == "tshift":
14191411
msg = "tshift requires timeseries"
14201412
request.node.add_marker(pytest.mark.xfail(reason=msg))
1421-
args = (0,) if transformation_func == "fillna" else ()
14221413
df = DataFrame({"A": [1, 1, np.nan], "B": [1, 2, 2]}, index=[1, 2, 3])
1414+
args = get_groupby_method_args(transformation_func, df)
14231415
gb = df.groupby("A", dropna=dropna)
14241416

14251417
buffer = []
@@ -1461,12 +1453,7 @@ def test_null_group_str_reducer_series(request, dropna, reduction_func):
14611453
ser = Series([1, 2, 2, 3], index=index)
14621454
gb = ser.groupby([1, 1, np.nan, np.nan], dropna=dropna)
14631455

1464-
if reduction_func == "corrwith":
1465-
args = (ser,)
1466-
elif reduction_func == "nth":
1467-
args = (0,)
1468-
else:
1469-
args = ()
1456+
args = get_groupby_method_args(reduction_func, ser)
14701457

14711458
# Manually handle reducers that don't fit the generic pattern
14721459
# Set expected with dropna=False, then replace if necessary
@@ -1506,8 +1493,8 @@ def test_null_group_str_transformer_series(request, dropna, transformation_func)
15061493
if transformation_func == "tshift":
15071494
msg = "tshift requires timeseries"
15081495
request.node.add_marker(pytest.mark.xfail(reason=msg))
1509-
args = (0,) if transformation_func == "fillna" else ()
15101496
ser = Series([1, 2, 2], index=[1, 2, 3])
1497+
args = get_groupby_method_args(transformation_func, ser)
15111498
gb = ser.groupby([1, 1, np.nan], dropna=dropna)
15121499

15131500
buffer = []

0 commit comments

Comments
 (0)