Skip to content

Commit 70f7a95

Browse files
committed
Add variant for solution 4
1 parent cf39e6a commit 70f7a95

File tree

2 files changed

+103
-1
lines changed

2 files changed

+103
-1
lines changed

slep006/cases_opt4b.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
from defs import (accuracy, group_cv, make_scorer, SelectKBest,
2+
LogisticRegressionCV, cross_validate,
3+
make_pipeline, X, y, my_groups, my_weights,
4+
my_other_weights)
5+
6+
# %%
7+
# Case A: weighted scoring and fitting
8+
9+
# Here we presume that GroupKFold requests `groups` by default.
10+
# We need to explicitly request weights in make_scorer and for
11+
# LogisticRegressionCV. Both of these consumers understand the meaning
12+
# of the key "sample_weight".
13+
14+
weighted_acc = make_scorer(accuracy, request_props=['sample_weight'])
15+
lr = LogisticRegressionCV(
16+
cv=group_cv,
17+
scoring=weighted_acc,
18+
).request_sample_weight(fit=['sample_weight'])
19+
cross_validate(lr, X, y, cv=group_cv,
20+
props={'sample_weight': my_weights, 'groups': my_groups},
21+
scoring=weighted_acc)
22+
23+
# Error handling: if props={'sample_eight': my_weights, ...} was passed,
24+
# cross_validate would raise an error, since 'sample_eight' was not requested
25+
# by any of its children.
26+
27+
# %%
28+
# Case B: weighted scoring and unweighted fitting
29+
30+
# Since LogisticRegressionCV requires that weights explicitly be requested,
31+
# removing that request means the fitting is unweighted.
32+
33+
weighted_acc = make_scorer(accuracy, request_props=['sample_weight'])
34+
lr = LogisticRegressionCV(
35+
cv=group_cv,
36+
scoring=weighted_acc,
37+
)
38+
cross_validate(lr, X, y, cv=group_cv,
39+
props={'sample_weight': my_weights, 'groups': my_groups},
40+
scoring=weighted_acc)
41+
42+
# %%
43+
# Case C: unweighted feature selection
44+
45+
# Like LogisticRegressionCV, SelectKBest needs to request weights explicitly.
46+
# Here it does not request them.
47+
48+
weighted_acc = make_scorer(accuracy, request_props=['sample_weight'])
49+
lr = LogisticRegressionCV(
50+
cv=group_cv,
51+
scoring=weighted_acc,
52+
).request_sample_weight(fit=['sample_weight'])
53+
sel = SelectKBest()
54+
pipe = make_pipeline(sel, lr)
55+
cross_validate(pipe, X, y, cv=group_cv,
56+
props={'sample_weight': my_weights, 'groups': my_groups},
57+
scoring=weighted_acc)
58+
59+
# %%
60+
# Case D: different scoring and fitting weights
61+
62+
# Despite make_scorer and LogisticRegressionCV both expecting a key
63+
# sample_weight, we can use aliases to pass different weights to different
64+
# consumers.
65+
66+
weighted_acc = make_scorer(accuracy,
67+
request_props={'scoring_weight': 'sample_weight'})
68+
lr = LogisticRegressionCV(
69+
cv=group_cv,
70+
scoring=weighted_acc,
71+
).request_sample_weight(fit='fitting_weight')
72+
cross_validate(lr, X, y, cv=group_cv,
73+
props={
74+
'scoring_weight': my_weights,
75+
'fitting_weight': my_other_weights,
76+
'groups': my_groups,
77+
},
78+
scoring=weighted_acc)

slep006/proposal.rst

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ Other related issues include: :issue:`1574`, :issue:`2630`, :issue:`3524`,
7979
:issue:`4632`, :issue:`4652`, :issue:`4660`, :issue:`4696`, :issue:`6322`,
8080
:issue:`7112`, :issue:`7646`, :issue:`7723`, :issue:`8127`, :issue:`8158`,
8181
:issue:`8710`, :issue:`8950`, :issue:`11429`, :issue:`12052`, :issue:`15282`,
82-
:issues:`15370`, :issue:`15425`.
82+
:issues:`15370`, :issue:`15425`, :issue:`18028`.
8383

8484
Desiderata
8585
----------
@@ -368,6 +368,14 @@ Disadvantages:
368368
`set_props_request` method (instead of the `request_props` constructor
369369
parameter approach) such that all legacy base estimators are
370370
automatically equipped.
371+
* Aliasing is a bit confusing in this design, in that the consumer still
372+
accepts the fit param by its original name (e.g. `sample_weight`) even if it
373+
has a request that specifies a different key given to the router (e.g.
374+
`fit_sample_weight`). This design has the advantage that the handling of
375+
props within a consumer is simple and unchanged; the complexity is in
376+
how it is forwarded the data by the router, but it may be conceptually
377+
difficult for users to understand. (This may be acceptable, as an advanced
378+
feature.)
371379
* For estimators to be cloned, this request information needs to be cloned with
372380
it. This implies one of: the request information be stored as a constructor
373381
paramerter; or `clone` is extended to explicitly copy request information.
@@ -389,6 +397,22 @@ Test cases:
389397

390398
.. literalinclude:: cases_opt4.py
391399

400+
Extensions and alternatives to the syntax considered while working on
401+
:pr:`16079`:
402+
403+
* `set_prop_request` and `get_props_request` have lists of props requested
404+
**for each method** i.e. fit, score, transform, predict and perhaps others.
405+
* `set_props_request` could be replaced by a method (or parameter) representing
406+
the routing of each prop that it consumes. For example, an estimator that
407+
consumes `sample_weight` would have a `request_sample_weight` method. One of
408+
the difficulties of this approach is automatically introducing
409+
`request_sample_weight` into classes inheriting from BaseEstimator without
410+
too much magic (e.g. meta-classes, which might be the simplest solution).
411+
412+
These are demonstrated together in the following:
413+
414+
.. literalinclude:: cases_opt4b.py
415+
392416
Naming
393417
------
394418

0 commit comments

Comments
 (0)