@@ -1141,6 +1141,7 @@ def get_orderings(args, grouper, grouped):
1141
1141
"""
1142
1142
orders = {} if "category_orders" not in args else args ["category_orders" ].copy ()
1143
1143
group_names = []
1144
+ group_values = {}
1144
1145
for group_name in grouped .groups :
1145
1146
if len (grouper ) == 1 :
1146
1147
group_name = (group_name ,)
@@ -1154,6 +1155,7 @@ def get_orderings(args, grouper, grouped):
1154
1155
for val in uniques :
1155
1156
if val not in orders [col ]:
1156
1157
orders [col ].append (val )
1158
+ group_values [col ] = sorted (uniques , key = orders [col ].index )
1157
1159
1158
1160
for i , col in reversed (list (enumerate (grouper ))):
1159
1161
if col != one_group :
@@ -1162,7 +1164,7 @@ def get_orderings(args, grouper, grouped):
1162
1164
key = lambda g : orders [col ].index (g [i ]) if g [i ] in orders [col ] else - 1 ,
1163
1165
)
1164
1166
1165
- return orders , group_names
1167
+ return orders , group_names , group_values
1166
1168
1167
1169
1168
1170
def make_figure (args , constructor , trace_patch = {}, layout_patch = {}):
@@ -1174,16 +1176,31 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
1174
1176
grouper = [x .grouper or one_group for x in grouped_mappings ] or [one_group ]
1175
1177
grouped = args ["data_frame" ].groupby (grouper , sort = False )
1176
1178
1177
- orders , sorted_group_names = get_orderings (args , grouper , grouped )
1179
+ orders , sorted_group_names , sorted_group_values = get_orderings (
1180
+ args , grouper , grouped
1181
+ )
1182
+
1183
+ col_labels = []
1184
+ row_labels = []
1185
+
1186
+ for m in grouped_mappings :
1187
+ if m .grouper :
1188
+ if m .facet == "col" :
1189
+ prefix = get_label (args , args ["facet_col" ]) + "="
1190
+ col_labels = [prefix + str (s ) for s in sorted_group_values [m .grouper ]]
1191
+ if m .facet == "row" :
1192
+ prefix = get_label (args , args ["facet_row" ]) + "="
1193
+ row_labels = [prefix + str (s ) for s in sorted_group_values [m .grouper ]]
1194
+ for val in sorted_group_values [m .grouper ]:
1195
+ if val not in m .val_map :
1196
+ m .val_map [val ] = m .sequence [len (m .val_map ) % len (m .sequence )]
1178
1197
1179
1198
subplot_type = _subplot_type_for_trace_type (constructor ().type )
1180
1199
1181
1200
trace_names_by_frame = {}
1182
1201
frames = OrderedDict ()
1183
1202
trendline_rows = []
1184
1203
nrows = ncols = 1
1185
- col_labels = []
1186
- row_labels = []
1187
1204
trace_name_labels = None
1188
1205
for group_name in sorted_group_names :
1189
1206
group = grouped .get_group (group_name if len (group_name ) > 1 else group_name [0 ])
@@ -1281,10 +1298,6 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
1281
1298
# Find row for trace, handling facet_row and marginal_x
1282
1299
if m .facet == "row" :
1283
1300
row = m .val_map [val ]
1284
- if args ["facet_row" ] and len (row_labels ) < row :
1285
- row_labels .append (
1286
- get_label (args , args ["facet_row" ]) + "=" + str (val )
1287
- )
1288
1301
else :
1289
1302
if (
1290
1303
bool (args .get ("marginal_x" , False ))
@@ -1298,10 +1311,6 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
1298
1311
# Find col for trace, handling facet_col and marginal_y
1299
1312
if m .facet == "col" :
1300
1313
col = m .val_map [val ]
1301
- if args ["facet_col" ] and len (col_labels ) < col :
1302
- col_labels .append (
1303
- get_label (args , args ["facet_col" ]) + "=" + str (val )
1304
- )
1305
1314
if facet_col_wrap : # assumes no facet_row, no marginals
1306
1315
row = 1 + ((col - 1 ) // facet_col_wrap )
1307
1316
col = 1 + ((col - 1 ) % facet_col_wrap )
0 commit comments