Skip to content

Commit 56d56ad

Browse files
committed
Pushing the docs to dev/ for branch: main, commit 439bda4752f6208f30938b534d3cb93fbdc4f285
1 parent 88f6530 commit 56d56ad

File tree

1,288 files changed

+5962
-5174
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

1,288 files changed

+5962
-5174
lines changed
Binary file not shown.

dev/_downloads/1dcd684ce26b8c407ec2c2d2101c5c73/plot_kernel_ridge_regression.py

Lines changed: 17 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -189,35 +189,27 @@
189189
# %%
190190
# Visualize the learning curves
191191
# -----------------------------
192+
from sklearn.model_selection import LearningCurveDisplay
192193

193-
from sklearn.model_selection import learning_curve
194-
195-
plt.figure()
194+
_, ax = plt.subplots()
196195

197196
svr = SVR(kernel="rbf", C=1e1, gamma=0.1)
198197
kr = KernelRidge(kernel="rbf", alpha=0.1, gamma=0.1)
199-
train_sizes, train_scores_svr, test_scores_svr = learning_curve(
200-
svr,
201-
X[:100],
202-
y[:100],
203-
train_sizes=np.linspace(0.1, 1, 10),
204-
scoring="neg_mean_squared_error",
205-
cv=10,
206-
)
207-
train_sizes_abs, train_scores_kr, test_scores_kr = learning_curve(
208-
kr,
209-
X[:100],
210-
y[:100],
211-
train_sizes=np.linspace(0.1, 1, 10),
212-
scoring="neg_mean_squared_error",
213-
cv=10,
214-
)
215198

216-
plt.plot(train_sizes, -test_scores_kr.mean(1), "o--", color="g", label="KRR")
217-
plt.plot(train_sizes, -test_scores_svr.mean(1), "o--", color="r", label="SVR")
218-
plt.xlabel("Train size")
219-
plt.ylabel("Mean Squared Error")
220-
plt.title("Learning curves")
221-
plt.legend(loc="best")
199+
common_params = {
200+
"X": X[:100],
201+
"y": y[:100],
202+
"train_sizes": np.linspace(0.1, 1, 10),
203+
"scoring": "neg_mean_squared_error",
204+
"negate_score": True,
205+
"score_name": "Mean Squared Error",
206+
"std_display_style": None,
207+
"ax": ax,
208+
}
209+
210+
LearningCurveDisplay.from_estimator(svr, **common_params)
211+
LearningCurveDisplay.from_estimator(kr, **common_params)
212+
ax.set_title("Learning curves")
213+
ax.legend(handles=ax.get_legend_handles_labels()[0], labels=["SVR", "KRR"])
222214

223215
plt.show()

0 commit comments

Comments
 (0)