Skip to content

Commit 8cc8e1a

Browse files
committed
Pushing the docs to dev/ for branch: main, commit 43550f011a3adcb86fa09a7d508a9929ab87e49d
1 parent c616218 commit 8cc8e1a

File tree

1,361 files changed

+4774
-4766
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

1,361 files changed

+4774
-4766
lines changed
Binary file not shown.

dev/_downloads/6383d955c013c730f9d211f15e261f38/plot_successive_halving_heatmap.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,10 @@
5454

5555
def make_heatmap(ax, gs, is_sh=False, make_cbar=False):
5656
"""Helper to make a heatmap."""
57-
results = pd.DataFrame.from_dict(gs.cv_results_)
58-
results["params_str"] = results.params.apply(str)
57+
results = pd.DataFrame(gs.cv_results_)
58+
results[["param_C", "param_gamma"]] = results[["param_C", "param_gamma"]].astype(
59+
np.float64
60+
)
5961
if is_sh:
6062
# SH dataframe: get mean_test_score values for the highest iter
6163
scores_matrix = results.sort_values("iter").pivot_table(

dev/_downloads/69f1bc3bab6ea5d622c5dd4cbd78227f/plot_successive_halving_heatmap.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
},
6363
"outputs": [],
6464
"source": [
65-
"def make_heatmap(ax, gs, is_sh=False, make_cbar=False):\n \"\"\"Helper to make a heatmap.\"\"\"\n results = pd.DataFrame.from_dict(gs.cv_results_)\n results[\"params_str\"] = results.params.apply(str)\n if is_sh:\n # SH dataframe: get mean_test_score values for the highest iter\n scores_matrix = results.sort_values(\"iter\").pivot_table(\n index=\"param_gamma\",\n columns=\"param_C\",\n values=\"mean_test_score\",\n aggfunc=\"last\",\n )\n else:\n scores_matrix = results.pivot(\n index=\"param_gamma\", columns=\"param_C\", values=\"mean_test_score\"\n )\n\n im = ax.imshow(scores_matrix)\n\n ax.set_xticks(np.arange(len(Cs)))\n ax.set_xticklabels([\"{:.0E}\".format(x) for x in Cs])\n ax.set_xlabel(\"C\", fontsize=15)\n\n ax.set_yticks(np.arange(len(gammas)))\n ax.set_yticklabels([\"{:.0E}\".format(x) for x in gammas])\n ax.set_ylabel(\"gamma\", fontsize=15)\n\n # Rotate the tick labels and set their alignment.\n plt.setp(ax.get_xticklabels(), rotation=45, ha=\"right\", rotation_mode=\"anchor\")\n\n if is_sh:\n iterations = results.pivot_table(\n index=\"param_gamma\", columns=\"param_C\", values=\"iter\", aggfunc=\"max\"\n ).values\n for i in range(len(gammas)):\n for j in range(len(Cs)):\n ax.text(\n j,\n i,\n iterations[i, j],\n ha=\"center\",\n va=\"center\",\n color=\"w\",\n fontsize=20,\n )\n\n if make_cbar:\n fig.subplots_adjust(right=0.8)\n cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])\n fig.colorbar(im, cax=cbar_ax)\n cbar_ax.set_ylabel(\"mean_test_score\", rotation=-90, va=\"bottom\", fontsize=15)\n\n\nfig, axes = plt.subplots(ncols=2, sharey=True)\nax1, ax2 = axes\n\nmake_heatmap(ax1, gsh, is_sh=True)\nmake_heatmap(ax2, gs, make_cbar=True)\n\nax1.set_title(\"Successive Halving\\ntime = {:.3f}s\".format(gsh_time), fontsize=15)\nax2.set_title(\"GridSearch\\ntime = {:.3f}s\".format(gs_time), fontsize=15)\n\nplt.show()"
65+
"def make_heatmap(ax, gs, is_sh=False, make_cbar=False):\n \"\"\"Helper to make a heatmap.\"\"\"\n results = pd.DataFrame(gs.cv_results_)\n results[[\"param_C\", \"param_gamma\"]] = results[[\"param_C\", \"param_gamma\"]].astype(\n np.float64\n )\n if is_sh:\n # SH dataframe: get mean_test_score values for the highest iter\n scores_matrix = results.sort_values(\"iter\").pivot_table(\n index=\"param_gamma\",\n columns=\"param_C\",\n values=\"mean_test_score\",\n aggfunc=\"last\",\n )\n else:\n scores_matrix = results.pivot(\n index=\"param_gamma\", columns=\"param_C\", values=\"mean_test_score\"\n )\n\n im = ax.imshow(scores_matrix)\n\n ax.set_xticks(np.arange(len(Cs)))\n ax.set_xticklabels([\"{:.0E}\".format(x) for x in Cs])\n ax.set_xlabel(\"C\", fontsize=15)\n\n ax.set_yticks(np.arange(len(gammas)))\n ax.set_yticklabels([\"{:.0E}\".format(x) for x in gammas])\n ax.set_ylabel(\"gamma\", fontsize=15)\n\n # Rotate the tick labels and set their alignment.\n plt.setp(ax.get_xticklabels(), rotation=45, ha=\"right\", rotation_mode=\"anchor\")\n\n if is_sh:\n iterations = results.pivot_table(\n index=\"param_gamma\", columns=\"param_C\", values=\"iter\", aggfunc=\"max\"\n ).values\n for i in range(len(gammas)):\n for j in range(len(Cs)):\n ax.text(\n j,\n i,\n iterations[i, j],\n ha=\"center\",\n va=\"center\",\n color=\"w\",\n fontsize=20,\n )\n\n if make_cbar:\n fig.subplots_adjust(right=0.8)\n cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])\n fig.colorbar(im, cax=cbar_ax)\n cbar_ax.set_ylabel(\"mean_test_score\", rotation=-90, va=\"bottom\", fontsize=15)\n\n\nfig, axes = plt.subplots(ncols=2, sharey=True)\nax1, ax2 = axes\n\nmake_heatmap(ax1, gsh, is_sh=True)\nmake_heatmap(ax2, gs, make_cbar=True)\n\nax1.set_title(\"Successive Halving\\ntime = {:.3f}s\".format(gsh_time), fontsize=15)\nax2.set_title(\"GridSearch\\ntime = {:.3f}s\".format(gs_time), fontsize=15)\n\nplt.show()"
6666
]
6767
},
6868
{
Binary file not shown.

dev/_downloads/scikit-learn-docs.zip

-25.2 KB
Binary file not shown.
-1 Bytes
1 Byte
72 Bytes
-171 Bytes
-9 Bytes

0 commit comments

Comments
 (0)