|
| 1 | +import jax |
| 2 | +import jax._src.lax_reference as lax_reference |
| 3 | +from jax import core |
| 4 | +from jax._src.util import safe_map |
| 5 | +import numpy |
| 6 | +from neurallogic import symbolic_primitives |
| 7 | + |
| 8 | + |
| 9 | +def symbolic_bind(prim, *args, **params): |
| 10 | + #print("\n---symbolic_bind:") |
| 11 | + #print("primitive: ", prim.name) |
| 12 | + #print("args: ", args) |
| 13 | + #print("params: ", params) |
| 14 | + symbolic_outvals = { |
| 15 | + 'and': symbolic_primitives.symbolic_and, |
| 16 | + 'broadcast_in_dim': symbolic_primitives.symbolic_broadcast_in_dim, |
| 17 | + 'xor': symbolic_primitives.symbolic_xor, |
| 18 | + 'not': symbolic_primitives.symbolic_not, |
| 19 | + 'reshape': lax_reference.reshape, |
| 20 | + 'reduce_or': symbolic_primitives.symbolic_reduce_or, |
| 21 | + 'reduce_sum': symbolic_primitives.symbolic_reduce_sum, |
| 22 | + 'convert_element_type': symbolic_primitives.symbolic_convert_element_type |
| 23 | + }[prim.name](*args, **params) |
| 24 | + return symbolic_outvals |
| 25 | + |
| 26 | + |
| 27 | + |
| 28 | +def eval_jaxpr(symbolic, jaxpr, consts, *args): |
| 29 | + """Evaluates a jaxpr by interpreting it as Python code. |
| 30 | +
|
| 31 | + Parameters |
| 32 | + ---------- |
| 33 | + symbolic : bool |
| 34 | + Whether to return symbolic values or concrete values. If symbolic is |
| 35 | + True, returns symbolic values, and if symbolic is False, returns |
| 36 | + concrete values. |
| 37 | + jaxpr : Jaxpr |
| 38 | + The jaxpr to interpret. |
| 39 | + consts : tuple |
| 40 | + Constant values for the jaxpr. |
| 41 | + args : tuple |
| 42 | + Arguments for the jaxpr. |
| 43 | +
|
| 44 | + Returns |
| 45 | + ------- |
| 46 | + out : tuple |
| 47 | + The result of evaluating the jaxpr. |
| 48 | + """ |
| 49 | + |
| 50 | + # Mapping from variable -> value |
| 51 | + env = {} |
| 52 | + symbolic_env = {} |
| 53 | + |
| 54 | + def read(var): |
| 55 | + # Literals are values baked into the Jaxpr |
| 56 | + if type(var) is core.Literal: |
| 57 | + return var.val |
| 58 | + return env[var] |
| 59 | + |
| 60 | + def symbolic_read(var): |
| 61 | + return symbolic_env[var] |
| 62 | + |
| 63 | + def write(var, val): |
| 64 | + env[var] = val |
| 65 | + |
| 66 | + def symbolic_write(var, val): |
| 67 | + symbolic_env[var] = val |
| 68 | + |
| 69 | + # Bind args and consts to environment |
| 70 | + if not symbolic: |
| 71 | + safe_map(write, jaxpr.invars, args) |
| 72 | + safe_map(write, jaxpr.constvars, consts) |
| 73 | + safe_map(symbolic_write, jaxpr.invars, args) |
| 74 | + safe_map(symbolic_write, jaxpr.constvars, consts) |
| 75 | + |
| 76 | + def eval_jaxpr_impl(jaxpr): |
| 77 | + # Loop through equations and evaluate primitives using `bind` |
| 78 | + for eqn in jaxpr.eqns: |
| 79 | + # Read inputs to equation from environment |
| 80 | + if not symbolic: |
| 81 | + invals = safe_map(read, eqn.invars) |
| 82 | + symbolic_invals = safe_map(symbolic_read, eqn.invars) |
| 83 | + prim = eqn.primitive |
| 84 | + if type(prim) is jax.core.CallPrimitive: |
| 85 | + # print(f"call primitive: {prim.name}") |
| 86 | + call_jaxpr = eqn.params['call_jaxpr'] |
| 87 | + if not symbolic: |
| 88 | + safe_map(write, call_jaxpr.invars, map(read, eqn.invars)) |
| 89 | + safe_map(symbolic_write, call_jaxpr.invars, |
| 90 | + map(symbolic_read, eqn.invars)) |
| 91 | + eval_jaxpr_impl(call_jaxpr) |
| 92 | + if not symbolic: |
| 93 | + safe_map(write, eqn.outvars, map(read, call_jaxpr.outvars)) |
| 94 | + safe_map(symbolic_write, eqn.outvars, map( |
| 95 | + symbolic_read, call_jaxpr.outvars)) |
| 96 | + else: |
| 97 | + # print(f"primitive: {prim.name}") |
| 98 | + if not symbolic: |
| 99 | + outvals = prim.bind(*invals, **eqn.params) |
| 100 | + symbolic_outvals = symbolic_bind( |
| 101 | + prim, *symbolic_invals, **eqn.params) |
| 102 | + # Primitives may return multiple outputs or not |
| 103 | + if not prim.multiple_results: |
| 104 | + if not symbolic: |
| 105 | + outvals = [outvals] |
| 106 | + symbolic_outvals = [symbolic_outvals] |
| 107 | + if not symbolic: |
| 108 | + #print(f"outvals: {type(outvals)}: {outvals}") |
| 109 | + #print( |
| 110 | + # f"symbolic_outvals: {type(symbolic_outvals)}: {symbolic_outvals}") |
| 111 | + # Check that the concrete and symbolic values are equal |
| 112 | + assert numpy.array_equal( |
| 113 | + numpy.array(outvals), symbolic_outvals) |
| 114 | + # Write the results of the primitive into the environment |
| 115 | + if not symbolic: |
| 116 | + safe_map(write, eqn.outvars, outvals) |
| 117 | + safe_map(symbolic_write, eqn.outvars, symbolic_outvals) |
| 118 | + |
| 119 | + # Read the final result of the Jaxpr from the environment |
| 120 | + eval_jaxpr_impl(jaxpr) |
| 121 | + if not symbolic: |
| 122 | + return safe_map(read, jaxpr.outvars)[0] |
| 123 | + else: |
| 124 | + return safe_map(symbolic_read, jaxpr.outvars)[0] |
| 125 | + |
| 126 | + |
| 127 | +def eval_jaxpr_concrete(jaxpr, *args): |
| 128 | + return eval_jaxpr(False, jaxpr.jaxpr, jaxpr.literals, *args) |
| 129 | + |
| 130 | + |
| 131 | +def eval_jaxpr_symbolic(jaxpr, *args): |
| 132 | + symbolic_jaxpr_literals = safe_map(lambda x: numpy.array(x, dtype=object), jaxpr.literals) |
| 133 | + symbolic_jaxpr_literals = symbolic_primitives.to_boolean_symbolic_values(symbolic_jaxpr_literals) |
| 134 | + return eval_jaxpr(True, jaxpr.jaxpr, symbolic_jaxpr_literals, *args) |
| 135 | + |
0 commit comments