Skip to content

Commit 1eb2642

Browse files
committed
Examples for case 0b
1 parent 70f7a95 commit 1eb2642

File tree

2 files changed

+91
-7
lines changed

2 files changed

+91
-7
lines changed

slep006/cases_opt0a.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy as np
22

3-
from defs import (accuracy, group_cv, make_scorer, SelectKBest,
3+
from defs import (accuracy, group_cv, get_scorer, SelectKBest,
44
LogisticRegressionCV, cross_validate,
55
make_pipeline, X, y, my_groups, my_weights,
66
my_other_weights)
@@ -28,7 +28,7 @@ def split(self, X, y, groups=None):
2828

2929
def get_n_splits(self, X, y, groups=None):
3030
groups = X[:, self.groups_idx]
31-
return self.base_cv.split(unwrap_X(X), y, groups=groups)
31+
return self.base_cv.get_n_splits(unwrap_X(X), y, groups=groups)
3232

3333

3434
wrapped_group_cv = WrappedGroupCV(group_cv)
@@ -39,11 +39,11 @@ def fit(self, X, y):
3939
return super().fit(unwrap_X(X), y, sample_weight=X[:, WEIGHT_IDX])
4040

4141

42-
weighted_acc = make_scorer(accuracy, request_props=['sample_weight'])
42+
acc_scorer = get_scorer('accuracy')
4343

4444

4545
def wrapped_weighted_acc(est, X, y, sample_weight=None):
46-
return weighted_acc(est, unwrap_X(X), y, sample_weight=X[:, WEIGHT_IDX])
46+
return acc_scorer(est, unwrap_X(X), y, sample_weight=X[:, WEIGHT_IDX])
4747

4848

4949
lr = WrappedLogisticRegressionCV(
@@ -81,7 +81,7 @@ def fit(self, X, y):
8181

8282
lr = WrappedLogisticRegressionCV(
8383
cv=wrapped_group_cv,
84-
scoring=weighted_acc,
84+
scoring=wrapped_weighted_acc,
8585
).set_props_request(['sample_weight'])
8686
sel = UnweightedWrappedSelectKBest()
8787
pipe = make_pipeline(sel, lr)

slep006/cases_opt0b.py

Lines changed: 86 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,91 @@
11
import pandas as pd
2-
from defs import (accuracy, group_cv, make_scorer, SelectKBest,
2+
from defs import (accuracy, group_cv, get_scorer, SelectKBest,
33
LogisticRegressionCV, cross_validate,
44
make_pipeline, X, y, my_groups, my_weights,
55
my_other_weights)
66

7-
# TODO
7+
X = pd.DataFrame(X)
8+
MY_GROUPS = pd.Series(my_groups)
9+
MY_WEIGHTS = pd.Series(my_weights)
10+
MY_OTHER_WEIGHTS = pd.Series(my_other_weights)
11+
12+
# %%
13+
# Case A: weighted scoring and fitting
14+
15+
16+
class WrappedGroupCV:
17+
def __init__(self, base_cv):
18+
self.base_cv = base_cv
19+
20+
def split(self, X, y, groups=None):
21+
return self.base_cv.split(X, y, groups=MY_GROUPS.loc[X.index])
22+
23+
def get_n_splits(self, X, y, groups=None):
24+
return self.base_cv.get_n_splits(X, y, groups=MY_GROUPS.loc[X.index])
25+
26+
27+
wrapped_group_cv = WrappedGroupCV(group_cv)
28+
29+
30+
class WeightedLogisticRegressionCV(LogisticRegressionCV):
31+
def fit(self, X, y):
32+
return super().fit(X, y, sample_weight=MY_WEIGHTS.loc[X.index])
33+
34+
35+
acc_scorer = get_scorer('accuracy')
36+
37+
38+
def wrapped_weighted_acc(est, X, y, sample_weight=None):
39+
return acc_scorer(est, X, y, sample_weight=MY_WEIGHTS.loc[X.index])
40+
41+
42+
lr = WeightedLogisticRegressionCV(
43+
cv=wrapped_group_cv,
44+
scoring=wrapped_weighted_acc,
45+
).set_props_request(['sample_weight'])
46+
cross_validate(lr, X, y,
47+
cv=wrapped_group_cv,
48+
scoring=wrapped_weighted_acc)
49+
50+
# %%
51+
# Case B: weighted scoring and unweighted fitting
52+
53+
lr = LogisticRegressionCV(
54+
cv=wrapped_group_cv,
55+
scoring=wrapped_weighted_acc,
56+
).set_props_request(['sample_weight'])
57+
cross_validate(lr, X, y,
58+
cv=wrapped_group_cv,
59+
scoring=wrapped_weighted_acc)
60+
61+
62+
# %%
63+
# Case C: unweighted feature selection
64+
65+
lr = WeightedLogisticRegressionCV(
66+
cv=wrapped_group_cv,
67+
scoring=wrapped_weighted_acc,
68+
).set_props_request(['sample_weight'])
69+
sel = SelectKBest()
70+
pipe = make_pipeline(sel, lr)
71+
cross_validate(pipe, X, y,
72+
cv=wrapped_group_cv,
73+
scoring=wrapped_weighted_acc)
74+
75+
# %%
76+
# Case D: different scoring and fitting weights
77+
78+
79+
def other_weighted_acc(est, X, y, sample_weight=None):
80+
return acc_scorer(est, X, y, sample_weight=MY_OTHER_WEIGHTS.loc[X.index])
81+
82+
83+
lr = WeightedLogisticRegressionCV(
84+
cv=wrapped_group_cv,
85+
scoring=other_weighted_acc,
86+
).set_props_request(['sample_weight'])
87+
sel = SelectKBest()
88+
pipe = make_pipeline(sel, lr)
89+
cross_validate(pipe, X, y,
90+
cv=wrapped_group_cv,
91+
scoring=other_weighted_acc)

0 commit comments

Comments
 (0)