@@ -3986,7 +3986,150 @@ def c_code(self, *args, **kwargs):
 complex_from_polar = ComplexFromPolar(name="complex_from_polar")
 
 
-class Composite(ScalarOp, HasInnerGraph):
+class ScalarInnerGraphOp(ScalarOp, HasInnerGraph):
3990+ """Includes boilerplate code for Python and C-implementation of Scalar Ops with inner graph."""
+
+    def __init__(self, *args, **kwargs):
+        self.prepare_node_called = set()
+
+    @property
+    def fn(self):
+        return None
+
+    @property
+    def inner_inputs(self):
+        return self.fgraph.inputs
+
+    @property
+    def inner_outputs(self):
+        return self.fgraph.outputs
+
+    @property
+    def py_perform_fn(self):
+        if hasattr(self, "_py_perform_fn"):
+            return self._py_perform_fn
+
+        from pytensor.link.utils import fgraph_to_python
+
+        def python_convert(op, node=None, **kwargs):
+            assert node is not None
+
+            n_outs = len(node.outputs)
+
+            if n_outs > 1:
+
+                def _perform(*inputs, outputs=[[None]] * n_outs):
+                    op.perform(node, inputs, outputs)
+                    return tuple(o[0] for o in outputs)
+
+            else:
+
+                def _perform(*inputs, outputs=[[None]]):
+                    op.perform(node, inputs, outputs)
+                    return outputs[0][0]
+
+            return _perform
+
+        self._py_perform_fn = fgraph_to_python(self.fgraph, python_convert)
+        return self._py_perform_fn
+
+    def impl(self, *inputs):
+        output_storage = [[None] for i in range(self.nout)]
+        self.perform(None, inputs, output_storage)
+        ret = to_return_values([storage[0] for storage in output_storage])
+        if self.nout > 1:
+            ret = tuple(ret)
+        return ret
+
+    def c_code_cache_version(self):
+        rval = list(self.c_code_cache_version_outer())
+        for x in self.fgraph.toposort():
+            xv = x.op.c_code_cache_version()
+            if xv:
+                rval.append(xv)
+            else:
+                return ()
+        return tuple(rval)
+
+    def c_header_dirs(self, **kwargs):
+        rval = sum(
+            (subnode.op.c_header_dirs(**kwargs) for subnode in self.fgraph.toposort()),
+            [],
+        )
+        return rval
+
+    def c_support_code(self, **kwargs):
+        # Remove duplicate code blocks by using a `set`
+        rval = {
+            subnode.op.c_support_code(**kwargs).strip()
+            for subnode in self.fgraph.toposort()
+        }
+        return "\n".join(sorted(rval))
+
+    def c_support_code_apply(self, node, name):
+        rval = []
+        for subnode, subnodename in zip(self.fgraph.toposort(), self.nodenames):
+            subnode_support_code = subnode.op.c_support_code_apply(
+                subnode, subnodename % dict(nodename=name)
+            )
+            if subnode_support_code:
+                rval.append(subnode_support_code)
+        # There should be no need to remove duplicate code blocks, because
+        # each block should have been specialized for the given nodename.
+        # Any block that isn't specialized should be returned via
+        # c_support_code instead of c_support_code_apply.
+        return "\n".join(rval)
+
+    def prepare_node(self, node, storage_map, compute_map, impl):
+        if impl not in self.prepare_node_called:
+            for n in list_of_nodes(self.inputs, self.outputs):
+                n.op.prepare_node(n, None, None, impl)
+            self.prepare_node_called.add(impl)
+
+    def __eq__(self, other):
+        if self is other:
+            return True
+        if (
+            type(self) != type(other)
+            or self.nin != other.nin
+            or self.nout != other.nout
+        ):
+            return False
+
+        # TODO FIXME: Why this? Shouldn't we expect equivalent inputs to this
+        # object to generate the same `_c_code`?
+        return self.c_code_template == other.c_code_template
+
+    def __hash__(self):
+        # Note that in general, the configparser settings at the time
+        # of code generation (__init__) affect the semantics of this Op.
+        # This function assumes that all relevant info about the configparser
+        # is embodied in _c_code. So the _c_code, rather than self.fgraph,
+        # is the signature of the semantics of this Op.
+        # _c_code is preserved through unpickling, so the Op will not change
+        # semantics when it is reloaded with different configparser
+        # settings.
+        #
+        # TODO FIXME: Doesn't the above just mean that we should be including
+        # the relevant "configparser settings" here? Also, why should we even
+        # care about the exact form of the generated C code when comparing
+        # `Op`s? All this smells of leaky concerns and interfaces.
+        return hash((type(self), self.nin, self.nout, self.c_code_template))
+
+    def __getstate__(self):
+        rval = dict(self.__dict__)
+        rval.pop("_c_code", None)
+        rval.pop("_py_perform_fn", None)
+        rval.pop("_fgraph", None)
+        rval.pop("prepare_node_called", None)
+        return rval
+
+    def __setstate__(self, d):
+        self.__dict__.update(d)
+        self.prepare_node_called = set()
+
+
+class Composite(ScalarInnerGraphOp):
     """
     Composite is an Op that takes a graph of scalar operations and
     produces c code for the whole graph. Its purpose is to implement loop
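To see the behavior this base class factors out, here is a minimal usage sketch built around `Composite`, its primary subclass. The sketch is not part of the commit; it assumes the public `pytensor.scalar` API (`float64`, `exp`, and the `Composite` constructor shown in this diff):

import pytensor.scalar as ps

# Build a small inner graph of scalar operations: two inputs, one output.
x = ps.float64("x")
y = ps.float64("y")
out = ps.exp(x + y) * x

# Wrap the whole graph into a single scalar Op with an inner graph.
comp = ps.Composite([x, y], [out])

# `impl` (inherited from ScalarInnerGraphOp after this refactor) evaluates
# the inner graph in pure Python through `py_perform_fn`.
print(comp.impl(1.0, 2.0))  # exp(3.0) * 1.0 ~= 20.0855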
@@ -4043,19 +4186,7 @@ def __init__(self, inputs, outputs, name="Composite"):
         self.outputs_type = tuple([output.type for output in outputs])
         self.nin = len(inputs)
         self.nout = len(outputs)
-        self.prepare_node_called = set()
-
-    @property
-    def fn(self):
-        return None
-
-    @property
-    def inner_inputs(self):
-        return self.fgraph.inputs
-
-    @property
-    def inner_outputs(self):
-        return self.fgraph.outputs
+        super().__init__()
 
     def __str__(self):
         return self.name
@@ -4076,35 +4207,6 @@ def make_new_inplace(self, output_types_preference=None, name=None):
         super(Composite, out).__init__(output_types_preference, name)
         return out
 
-    @property
-    def py_perform(self):
-        if hasattr(self, "_py_perform_fn"):
-            return self._py_perform_fn
-
-        from pytensor.link.utils import fgraph_to_python
-
-        def python_convert(op, node=None, **kwargs):
-            assert node is not None
-
-            n_outs = len(node.outputs)
-
-            if n_outs > 1:
-
-                def _perform(*inputs, outputs=[[None]] * n_outs):
-                    op.perform(node, inputs, outputs)
-                    return tuple(o[0] for o in outputs)
-
-            else:
-
-                def _perform(*inputs, outputs=[[None]]):
-                    op.perform(node, inputs, outputs)
-                    return outputs[0][0]
-
-            return _perform
-
-        self._py_perform_fn = fgraph_to_python(self.fgraph, python_convert)
-        return self._py_perform_fn
-
     @property
     def fgraph(self):
         if hasattr(self, "_fgraph"):
@@ -4139,12 +4241,6 @@ def fgraph(self):
         self._fgraph = fgraph
         return self._fgraph
 
-    def prepare_node(self, node, storage_map, compute_map, impl):
-        if impl not in self.prepare_node_called:
-            for n in list_of_nodes(self.inputs, self.outputs):
-                n.op.prepare_node(n, None, None, impl)
-            self.prepare_node_called.add(impl)
-
     def clone_float32(self):
         # This will not modify the fgraph or the nodes
         new_ins, new_outs = composite_f32.apply(self.fgraph)
@@ -4155,8 +4251,6 @@ def clone(self):
         return Composite(new_ins, new_outs)
 
     def output_types(self, input_types):
-        # TODO FIXME: What's the intended purpose/use of this method, and why
-        # does it even need to be a method?
         if tuple(input_types) != self.inputs_type:
             raise TypeError(
                 f"Wrong types for Composite. Expected {self.inputs_type}, got {tuple(input_types)}."
@@ -4183,63 +4277,13 @@ def make_node(self, *inputs):
         return node
 
     def perform(self, node, inputs, output_storage):
-        outputs = self.py_perform(*inputs)
+        outputs = self.py_perform_fn(*inputs)
         for storage, out_val in zip(output_storage, outputs):
             storage[0] = out_val
 
-    def impl(self, *inputs):
-        output_storage = [[None] for i in range(self.nout)]
-        self.perform(None, inputs, output_storage)
-        ret = to_return_values([storage[0] for storage in output_storage])
-        if self.nout > 1:
-            ret = tuple(ret)
-        return ret
-
     def grad(self, inputs, output_grads):
         raise NotImplementedError("grad is not implemented for Composite")
 
-    def __eq__(self, other):
-        if self is other:
-            return True
-        if (
-            type(self) != type(other)
-            or self.nin != other.nin
-            or self.nout != other.nout
-        ):
-            return False
-
-        # TODO FIXME: Why this? Shouldn't we expect equivalent inputs to this
-        # object to generate the same `_c_code`?
-        return self.c_code_template == other.c_code_template
-
-    def __hash__(self):
-        # Note that in general, the configparser settings at the time
-        # of code generation (__init__) affect the semantics of this Op.
-        # This function assumes that all relevant info about the configparser
-        # is embodied in _c_code. So the _c_code, rather than self.fgraph,
-        # is the signature of the semantics of this Op.
-        # _c_code is preserved through unpickling, so the Op will not change
-        # semantics when it is reloaded with different configparser
-        # settings.
-        #
-        # TODO FIXME: Doesn't the above just mean that we should be including
-        # the relevant "configparser settings" here? Also, why should we even
-        # care about the exact form of the generated C code when comparing
-        # `Op`s? All this smells of leaky concerns and interfaces.
-        return hash((type(self), self.nin, self.nout, self.c_code_template))
-
-    def __getstate__(self):
-        rval = dict(self.__dict__)
-        rval.pop("_c_code", None)
-        rval.pop("_py_perform_fn", None)
-        rval.pop("_fgraph", None)
-        rval.pop("prepare_node_called", None)
-        return rval
-
-    def __setstate__(self, d):
-        self.__dict__.update(d)
-        self.prepare_node_called = set()
-
     @property
     def c_code_template(self):
         from pytensor.link.c.interface import CLinkerType
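A note on the `__eq__`/`__hash__` pair that moved into `ScalarInnerGraphOp` in the first hunk: identity is keyed on the generated C code template (`c_code_template`, defined just below) rather than on the inner graph object itself. A hypothetical sketch of the intended consequence, under the same `pytensor.scalar` API assumption as above, where deterministic code generation makes structurally identical `Composite`s interchangeable:

import pytensor.scalar as ps

def make_comp():
    a = ps.float64("a")
    b = ps.float64("b")
    return ps.Composite([a, b], [a * b + a])

c1, c2 = make_comp(), make_comp()

# Distinct Python objects, but (assuming deterministic codegen) the same
# C code template, so caching and graph merging treat them as one Op.
assert c1 is not c2
assert c1 == c2 and hash(c1) == hash(c2)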
@@ -4317,44 +4361,8 @@ def c_code(self, node, nodename, inames, onames, sub):
 
         return self.c_code_template % d
 
-    def c_code_cache_version(self):
-        rval = [3]
-        for x in self.fgraph.toposort():
-            xv = x.op.c_code_cache_version()
-            if xv:
-                rval.append(xv)
-            else:
-                return ()
-        return tuple(rval)
-
-    def c_header_dirs(self, **kwargs):
-        rval = sum(
-            (subnode.op.c_header_dirs(**kwargs) for subnode in self.fgraph.toposort()),
-            [],
-        )
-        return rval
-
-    def c_support_code(self, **kwargs):
-        # Remove duplicate code blocks by using a `set`
-        rval = {
-            subnode.op.c_support_code(**kwargs).strip()
-            for subnode in self.fgraph.toposort()
-        }
-        return "\n".join(sorted(rval))
-
-    def c_support_code_apply(self, node, name):
-        rval = []
-        for subnode, subnodename in zip(self.fgraph.toposort(), self.nodenames):
-            subnode_support_code = subnode.op.c_support_code_apply(
-                subnode, subnodename % dict(nodename=name)
-            )
-            if subnode_support_code:
-                rval.append(subnode_support_code)
-        # there should be no need to remove duplicate code blocks because
-        # each block should have been specialized for the given nodename.
-        # Any block that isn't specialized should be returned via
-        # c_support_code instead of c_support_code_apply.
-        return "\n".join(rval)
+    def c_code_cache_version_outer(self) -> Tuple[int, ...]:
+        return (3,)
 
 
 class Compositef32:
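The new `c_code_cache_version_outer` hook is what lets the shared `c_code_cache_version` in `ScalarInnerGraphOp` replace the hard-coded `[3]` deleted above: each subclass contributes its own version tuple, and the base class appends the cache version of every inner node, returning `()` (never cache) as soon as one inner op is unversioned. A sketch with a hypothetical subclass, assuming this refactor is in place:

from typing import Tuple

from pytensor.scalar.basic import ScalarInnerGraphOp

class MyFusedScalarOp(ScalarInnerGraphOp):  # hypothetical example subclass
    def c_code_cache_version_outer(self) -> Tuple[int, ...]:
        # Bump this when the subclass's own C template changes; the base
        # class combines it with the versions of all inner nodes.
        return (1,)

For an instance `op`, `op.c_code_cache_version()` would then return a tuple like `(1, (16,), (13,))`: the outer version followed by one entry per inner node.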