5
5
import jax ._src .lax_reference as lax_reference
6
6
import jaxlib
7
7
8
+
8
9
def convert_iterable_type (x : list , new_type ):
9
10
if new_type == list :
10
11
return x
@@ -15,51 +16,62 @@ def convert_iterable_type(x: list, new_type):
15
16
elif new_type == jaxlib .xla_extension .DeviceArray :
16
17
return jax .numpy .array (x , dtype = object )
17
18
else :
18
- raise NotImplementedError (f"Cannot convert type { type (x )} to type { new_type } " )
19
+ raise NotImplementedError (
20
+ f"Cannot convert type { type (x )} to type { new_type } " )
21
+
19
22
20
23
@dispatch
21
24
def map_at_elements (x : list , func : typing .Callable ):
22
25
return convert_iterable_type ([map_at_elements (item , func ) for item in x ], type (x ))
23
26
27
+
24
28
@dispatch
25
29
def map_at_elements (x : numpy .ndarray , func : typing .Callable ):
26
30
return convert_iterable_type ([map_at_elements (item , func ) for item in x ], type (x ))
27
31
32
+
28
33
@dispatch
29
34
def map_at_elements (x : jax .numpy .ndarray , func : typing .Callable ):
30
35
if x .ndim == 0 :
31
36
return func (x .item ())
32
37
return convert_iterable_type ([map_at_elements (item , func ) for item in x ], type (x ))
33
38
39
+
34
40
@dispatch
35
41
def map_at_elements (x : str , func : typing .Callable ):
36
42
return func (x )
37
43
44
+
38
45
@dispatch
39
46
def map_at_elements (x , func : typing .Callable ):
40
47
return func (x )
41
48
49
+
42
50
@dispatch
43
51
def to_boolean_value_string (x : bool ):
44
52
return 'True' if x else 'False'
45
53
54
+
46
55
@dispatch
47
56
def to_boolean_value_string (x : numpy .bool_ ):
48
57
return 'True' if x else 'False'
49
58
59
+
50
60
@dispatch
51
61
def to_boolean_value_string (x : int ):
52
62
return 'True' if x == 1.0 else 'False'
53
63
64
+
54
65
@dispatch
55
66
def to_boolean_value_string (x : float ):
56
67
return 'True' if x == 1.0 else 'False'
57
68
69
+
58
70
@dispatch
59
71
def to_boolean_value_string (x : str ):
60
- if x == '1' or x == '1.0' or x == 'True' :
72
+ if x == '1' or x == '1.0' or x == 'True' :
61
73
return 'True'
62
- elif x == '0' or x == '0.0' or x == 'False' :
74
+ elif x == '0' or x == '0.0' or x == 'False' :
63
75
return 'False'
64
76
else :
65
77
return x
@@ -86,7 +98,6 @@ def unary_operator(operator: str, x: list):
86
98
87
99
@dispatch
88
100
def binary_infix_operator (operator : str , a : str , b : str , bracket : bool = False ) -> str :
89
- # We need to specify bracket because Python cannot evaluate expressions with too many nested parantheses
90
101
if bracket :
91
102
return f"({ a } ) { operator } ({ b } )"
92
103
return f"{ a } { operator } { b } "
@@ -160,17 +171,17 @@ def symbolic_sum(*args, **kwargs):
160
171
else :
161
172
return binary_infix_operator ("+" , * args , ** kwargs )
162
173
163
- # Uses the lax reference implementation of broadcast_in_dim to
164
- # implement a symbolic version of broadcast_in_dim
165
-
166
174
167
175
def symbolic_broadcast_in_dim (* args , ** kwargs ):
176
+ # Uses the lax reference implementation of broadcast_in_dim to
177
+ # implement a symbolic version of broadcast_in_dim
168
178
return lax_reference .broadcast_in_dim (* args , ** kwargs )
169
179
170
180
171
181
def symbolic_convert_element_type_impl (x , dtype ):
172
182
if dtype == numpy .int32 or dtype == numpy .int64 :
173
183
dtype = "int"
184
+
174
185
def convert (x ):
175
186
return f"{ dtype } ({ x } )"
176
187
return map_at_elements (x , convert )
@@ -187,13 +198,13 @@ def symbolic_convert_element_type(*args, **kwargs):
187
198
return symbolic_convert_element_type_impl (* args , dtype = kwargs ['new_dtype' ])
188
199
189
200
190
- # This function is a hack to get around the fact that JAX doesn't
191
- # support symbolic reduction operations. It takes a symbolic reduction
192
- # operation and a symbolic initial value and returns a function that
193
- # performs the reduction operation on a numpy array.
194
201
195
202
196
203
def make_symbolic_reducer (py_binop , init_val ):
204
+ # This function is a hack to get around the fact that JAX doesn't
205
+ # support symbolic reduction operations. It takes a symbolic reduction
206
+ # operation and a symbolic initial value and returns a function that
207
+ # performs the reduction operation on a numpy array.
197
208
def reducer (operand , axis ):
198
209
# axis=None means we are reducing over all axes of the operand.
199
210
axis = range (numpy .ndim (operand )) if axis is None else axis
0 commit comments