1
+ from typing import Callable
2
+
1
3
import jax
2
4
import jax ._src .lax_reference as lax_reference
3
5
import numpy
@@ -19,152 +21,83 @@ def all_concrete_values(data):
19
21
return True
20
22
21
23
22
- def symbolic_not ( * args , ** kwargs ):
24
+ def symbolic_f ( concrete_f : Callable , symbolic_f_name : str , * args , ** kwargs ):
23
25
if all_concrete_values ([* args ]):
24
- return numpy . logical_not (* args , ** kwargs )
26
+ return concrete_f (* args , ** kwargs )
25
27
else :
26
- return symbolic_operator .symbolic_operator ('numpy.logical_not' , * args , ** kwargs )
28
+ return symbolic_operator .symbolic_operator (symbolic_f_name , * args , ** kwargs )
29
+
30
+
31
+ def symbolic_not (* args , ** kwargs ):
32
+ return symbolic_f (numpy .logical_not , 'numpy.logical_not' , * args , ** kwargs )
27
33
28
34
29
35
def symbolic_eq (* args , ** kwargs ):
30
- if all_concrete_values ([* args ]):
31
- return lax_reference .eq (* args , ** kwargs )
32
- else :
33
- return symbolic_operator .symbolic_operator ('lax_reference.eq' , * args , ** kwargs )
36
+ return symbolic_f (lax_reference .eq , 'lax_reference.eq' , * args , ** kwargs )
34
37
35
38
36
39
def symbolic_ne (* args , ** kwargs ):
37
- if all_concrete_values ([* args ]):
38
- return lax_reference .ne (* args , ** kwargs )
39
- else :
40
- return symbolic_operator .symbolic_operator ('lax_reference.ne' , * args , ** kwargs )
40
+ return symbolic_f (lax_reference .ne , 'lax_reference.ne' , * args , ** kwargs )
41
41
42
42
43
43
def symbolic_le (* args , ** kwargs ):
44
- if all_concrete_values ([* args ]):
45
- return lax_reference .le (* args , ** kwargs )
46
- else :
47
- return symbolic_operator .symbolic_operator ('lax_reference.le' , * args , ** kwargs )
44
+ return symbolic_f (lax_reference .le , 'lax_reference.le' , * args , ** kwargs )
48
45
49
46
50
47
def symbolic_lt (* args , ** kwargs ):
51
- if all_concrete_values ([* args ]):
52
- return lax_reference .lt (* args , ** kwargs )
53
- else :
54
- return symbolic_operator .symbolic_operator ('lax_reference.lt' , * args , ** kwargs )
48
+ return symbolic_f (lax_reference .lt , 'lax_reference.lt' , * args , ** kwargs )
55
49
56
50
57
51
def symbolic_ge (* args , ** kwargs ):
58
- if all_concrete_values ([* args ]):
59
- return lax_reference .ge (* args , ** kwargs )
60
- else :
61
- return symbolic_operator .symbolic_operator ('lax_reference.ge' , * args , ** kwargs )
52
+ return symbolic_f (lax_reference .ge , 'lax_reference.ge' , * args , ** kwargs )
62
53
63
54
64
55
def symbolic_gt (* args , ** kwargs ):
65
- if all_concrete_values ([* args ]):
66
- return lax_reference .gt (* args , ** kwargs )
67
- else :
68
- return symbolic_operator .symbolic_operator ('lax_reference.gt' , * args , ** kwargs )
56
+ return symbolic_f (lax_reference .gt , 'lax_reference.gt' , * args , ** kwargs )
69
57
70
58
71
59
def symbolic_abs (* args , ** kwargs ):
72
- if all_concrete_values ([* args ]):
73
- return lax_reference .abs (* args , ** kwargs )
74
- else :
75
- return symbolic_operator .symbolic_operator ('numpy.absolute' , * args , ** kwargs )
60
+ return symbolic_f (lax_reference .abs , 'numpy.absolute' , * args , ** kwargs )
76
61
77
62
78
63
def symbolic_add (* args , ** kwargs ):
79
- if all_concrete_values ([* args ]):
80
- return lax_reference .add (* args , ** kwargs )
81
- else :
82
- return symbolic_operator .symbolic_operator ('numpy.add' , * args , ** kwargs )
64
+ return symbolic_f (lax_reference .add , 'numpy.add' , * args , ** kwargs )
83
65
84
66
85
67
def symbolic_sub (* args , ** kwargs ):
86
- if all_concrete_values ([* args ]):
87
- return lax_reference .sub (* args , ** kwargs )
88
- else :
89
- return symbolic_operator .symbolic_operator ('numpy.subtract' , * args , ** kwargs )
68
+ return symbolic_f (lax_reference .sub , 'numpy.subtract' , * args , ** kwargs )
90
69
91
70
92
71
def symbolic_mul (* args , ** kwargs ):
93
- if all_concrete_values ([* args ]):
94
- return lax_reference .mul (* args , ** kwargs )
95
- else :
96
- return symbolic_operator .symbolic_operator ('numpy.multiply' , * args , ** kwargs )
72
+ return symbolic_f (lax_reference .mul , 'numpy.multiply' , * args , ** kwargs )
97
73
98
74
99
75
def symbolic_div (* args , ** kwargs ):
100
- if all_concrete_values ([* args ]):
101
- return lax_reference .div (* args , ** kwargs )
102
- else :
103
- return symbolic_operator .symbolic_operator ('lax_reference.div' , * args , ** kwargs )
76
+ return symbolic_f (lax_reference .div , 'lax_reference.div' , * args , ** kwargs )
104
77
105
78
106
79
def symbolic_max (* args , ** kwargs ):
107
- if all_concrete_values ([* args ]):
108
- return lax_reference .max (* args , ** kwargs )
109
- else :
110
- r = symbolic_operator .symbolic_operator ('numpy.maximum' , * args , ** kwargs )
111
- return r
80
+ return symbolic_f (lax_reference .max , 'numpy.maximum' , * args , ** kwargs )
112
81
113
82
114
83
def symbolic_min (* args , ** kwargs ):
115
- if all_concrete_values ([* args ]):
116
- return lax_reference .min (* args , ** kwargs )
117
- else :
118
- return symbolic_operator .symbolic_operator ('numpy.minimum' , * args , ** kwargs )
119
-
120
-
121
- def symbolic_select_n (* args , ** kwargs ):
122
- '''
123
- Important comment from lax.py
124
- # Caution! The select_n_p primitive has the *opposite* order of arguments to
125
- # select(). This is because it implements `select_n`.
126
- '''
127
- pred = args [0 ]
128
- on_true = args [1 ]
129
- on_false = args [2 ]
130
- if all_concrete_values ([* args ]):
131
- # swap order of on_true and on_false
132
- return lax_reference .select (pred , on_false , on_true )
133
- else :
134
- # swap order of on_true and on_false
135
- # TODO: need a more general solution to unquoting symbolic strings
136
- evaluable_pred = symbolic_representation .symbolic_representation (pred )
137
- evaluable_on_true = symbolic_representation .symbolic_representation (on_true )
138
- evaluable_on_false = symbolic_representation .symbolic_representation (on_false )
139
- return f'lax_reference.select({ evaluable_pred } , { evaluable_on_false } , { evaluable_on_true } )'
84
+ return symbolic_f (lax_reference .min , 'numpy.minimum' , * args , ** kwargs )
140
85
141
86
142
87
def symbolic_and (* args , ** kwargs ):
143
- if all_concrete_values ([* args ]):
144
- return numpy .logical_and (* args , ** kwargs )
145
- else :
146
- return symbolic_operator .symbolic_operator ('numpy.logical_and' , * args , ** kwargs )
88
+ return symbolic_f (numpy .logical_and , 'numpy.logical_and' , * args , ** kwargs )
147
89
148
90
149
91
def symbolic_or (* args , ** kwargs ):
150
- if all_concrete_values ([* args ]):
151
- return numpy .logical_or (* args , ** kwargs )
152
- else :
153
- return symbolic_operator .symbolic_operator ('numpy.logical_or' , * args , ** kwargs )
92
+ return symbolic_f (numpy .logical_or , 'numpy.logical_or' , * args , ** kwargs )
154
93
155
94
156
95
def symbolic_xor (* args , ** kwargs ):
157
- if all_concrete_values ([* args ]):
158
- return numpy .logical_xor (* args , ** kwargs )
159
- else :
160
- return symbolic_operator .symbolic_operator ('numpy.logical_xor' , * args , ** kwargs )
96
+ return symbolic_f (numpy .logical_xor , 'numpy.logical_xor' , * args , ** kwargs )
161
97
162
98
163
99
def symbolic_sum (* args , ** kwargs ):
164
- if all_concrete_values ([* args ]):
165
- return lax_reference .sum (* args , ** kwargs )
166
- else :
167
- return symbolic_operator .symbolic_operator ('lax_reference.sum' , * args , ** kwargs )
100
+ return symbolic_f (lax_reference .sum , 'lax_reference.sum' , * args , ** kwargs )
168
101
169
102
170
103
def symbolic_broadcast_in_dim (* args , ** kwargs ):
@@ -194,6 +127,29 @@ def convert_element_type(x, dtype):
194
127
return convert_element_type (* args , dtype = kwargs ['new_dtype' ])
195
128
196
129
130
+ def symbolic_select_n (* args , ** kwargs ):
131
+ '''
132
+ Important comment from lax.py
133
+ # Caution! The select_n_p primitive has the *opposite* order of arguments to
134
+ # select(). This is because it implements `select_n`.
135
+ '''
136
+ pred = args [0 ]
137
+ on_true = args [1 ]
138
+ on_false = args [2 ]
139
+ if all_concrete_values ([* args ]):
140
+ # swap order of on_true and on_false
141
+ return lax_reference .select (pred , on_false , on_true )
142
+ else :
143
+ # swap order of on_true and on_false
144
+ # TODO: need a more general solution to unquoting symbolic strings
145
+ evaluable_pred = symbolic_representation .symbolic_representation (pred )
146
+ evaluable_on_true = symbolic_representation .symbolic_representation (
147
+ on_true )
148
+ evaluable_on_false = symbolic_representation .symbolic_representation (
149
+ on_false )
150
+ return f'lax_reference.select({ evaluable_pred } , { evaluable_on_false } , { evaluable_on_true } )'
151
+
152
+
197
153
def make_symbolic_reducer (py_binop , init_val ):
198
154
# This function is a hack to get around the fact that JAX doesn't
199
155
# support symbolic reduction operations. It takes a symbolic reduction
0 commit comments