Skip to content

Commit 235cdce

Browse files
preload val_map from orders
1 parent c7234fc commit 235cdce

File tree

1 file changed

+21
-12
lines changed
  • packages/python/plotly/plotly/express

1 file changed

+21
-12
lines changed

packages/python/plotly/plotly/express/_core.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1141,6 +1141,7 @@ def get_orderings(args, grouper, grouped):
11411141
"""
11421142
orders = {} if "category_orders" not in args else args["category_orders"].copy()
11431143
group_names = []
1144+
group_values = {}
11441145
for group_name in grouped.groups:
11451146
if len(grouper) == 1:
11461147
group_name = (group_name,)
@@ -1154,6 +1155,7 @@ def get_orderings(args, grouper, grouped):
11541155
for val in uniques:
11551156
if val not in orders[col]:
11561157
orders[col].append(val)
1158+
group_values[col] = sorted(uniques, key=orders[col].index)
11571159

11581160
for i, col in reversed(list(enumerate(grouper))):
11591161
if col != one_group:
@@ -1162,7 +1164,7 @@ def get_orderings(args, grouper, grouped):
11621164
key=lambda g: orders[col].index(g[i]) if g[i] in orders[col] else -1,
11631165
)
11641166

1165-
return orders, group_names
1167+
return orders, group_names, group_values
11661168

11671169

11681170
def make_figure(args, constructor, trace_patch={}, layout_patch={}):
@@ -1174,16 +1176,31 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
11741176
grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group]
11751177
grouped = args["data_frame"].groupby(grouper, sort=False)
11761178

1177-
orders, sorted_group_names = get_orderings(args, grouper, grouped)
1179+
orders, sorted_group_names, sorted_group_values = get_orderings(
1180+
args, grouper, grouped
1181+
)
1182+
1183+
col_labels = []
1184+
row_labels = []
1185+
1186+
for m in grouped_mappings:
1187+
if m.grouper:
1188+
if m.facet == "col":
1189+
prefix = get_label(args, args["facet_col"]) + "="
1190+
col_labels = [prefix + str(s) for s in sorted_group_values[m.grouper]]
1191+
if m.facet == "row":
1192+
prefix = get_label(args, args["facet_row"]) + "="
1193+
row_labels = [prefix + str(s) for s in sorted_group_values[m.grouper]]
1194+
for val in sorted_group_values[m.grouper]:
1195+
if val not in m.val_map:
1196+
m.val_map[val] = m.sequence[len(m.val_map) % len(m.sequence)]
11781197

11791198
subplot_type = _subplot_type_for_trace_type(constructor().type)
11801199

11811200
trace_names_by_frame = {}
11821201
frames = OrderedDict()
11831202
trendline_rows = []
11841203
nrows = ncols = 1
1185-
col_labels = []
1186-
row_labels = []
11871204
trace_name_labels = None
11881205
for group_name in sorted_group_names:
11891206
group = grouped.get_group(group_name if len(group_name) > 1 else group_name[0])
@@ -1281,10 +1298,6 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
12811298
# Find row for trace, handling facet_row and marginal_x
12821299
if m.facet == "row":
12831300
row = m.val_map[val]
1284-
if args["facet_row"] and len(row_labels) < row:
1285-
row_labels.append(
1286-
get_label(args, args["facet_row"]) + "=" + str(val)
1287-
)
12881301
else:
12891302
if (
12901303
bool(args.get("marginal_x", False))
@@ -1298,10 +1311,6 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
12981311
# Find col for trace, handling facet_col and marginal_y
12991312
if m.facet == "col":
13001313
col = m.val_map[val]
1301-
if args["facet_col"] and len(col_labels) < col:
1302-
col_labels.append(
1303-
get_label(args, args["facet_col"]) + "=" + str(val)
1304-
)
13051314
if facet_col_wrap: # assumes no facet_row, no marginals
13061315
row = 1 + ((col - 1) // facet_col_wrap)
13071316
col = 1 + ((col - 1) % facet_col_wrap)

0 commit comments

Comments
 (0)