source: coopr.pyomo/trunk/coopr/pyomo/base/expr.py @ 2201

Last change on this file since 2201 was 2201, checked in by wehart, 10 years ago

Update to Coopr to account for changes in PyUtilib? package names.

  • Property svn:executable set to *
File size: 23.1 KB
Line 
1#  _________________________________________________________________________
2#
3#  Coopr: A COmmon Optimization Python Repository
4#  Copyright (c) 2008 Sandia Corporation.
5#  This software is distributed under the BSD License.
6#  Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
7#  the U.S. Government retains certain rights in this software.
8#  For more information, see the Coopr README.txt file.
9#  _________________________________________________________________________
10
11__all__ = ['Expression', '_LessThanExpression', '_GreaterThanExpression',
12        '_LessThanOrEqualExpression', '_GreaterThanOrEqualExpression',
13        '_EqualToExpression', '_IdentityExpression' , 'generate_expression']
14
15from plugin import *
16from pyutilib.component.core import *
17from numvalue import *
18from param import _ParamBase
19from var import _VarBase
20
21import sys
22import copy
23
24
25class Expression(NumericValue):
26    """An object that defines a mathematical expression that can be evaluated"""
27
28    def __init__(self, nargs=-1, name="UNKNOWN", operation=None, args=None, tuple_op=False):
29        """Construct an expression with an operation and a set of arguments"""
30        NumericValue.__init__(self, name=name)
31        self._args=args
32        self._nargs=nargs
33        self._operation=operation
34        self._tuple_op=tuple_op
35        self.verify()
36
37    def pprint(self, ostream=None, nested=True, eol_flag=True):
38        """Print this expression"""
39        if ostream is None:
40           ostream = sys.stdout
41        if nested:
42           print >>ostream, self.name + "(",
43           first=True
44           for arg in self._args:
45             if first==False:
46                print >>ostream, ",",
47             if isinstance(arg,Expression):
48                arg.pprint(ostream=ostream, nested=nested, eol_flag=False)
49             else:
50                print >>ostream, str(arg),
51             first=False
52           if eol_flag==True:
53              print >>ostream, ")"
54           else:
55              print >>ostream, ")",
56
57    def clone(self, args=()):
58        """Clone this object using the specified arguments"""
59        if self.__class__ == Expression:
60           return Expression(nargs=self._nargs, name=self.name, operation=self._operation, tuple_op=self._tuple_op, args=args)
61        return self.__class__(args=args)
62
63    def verify(self):
64        """Throw an exception if the list of arguments differs from the
65           specified number of arguments allowed"""
66        if self._args is None:
67           return
68        if (self._nargs != -1) and (self._nargs != len(self._args)):    #pragma:nocover
69           raise ValueError, "There were " + `len(self._args)` + " arguments specified for expression " + self.name + " but this expression requires " + `self._nargs` + " arguments"
70        for arg in self._args:
71            if type(arg) is float:      #pragma:nocover
72                raise ValueError, "Argument for expression "+self.name+" is a float!"
73            if (isinstance(arg,_ParamBase) or isinstance(arg,_VarBase)) and arg.dim() > 0:
74                raise ValueError, "Argument for expression "+self.name+" is an n-ary numeric value: "+arg.name
75
76    #
77    # this method contrast with the fixed_value() method.
78    # the fixed_value() method returns true iff the value is
79    # an atomic constant.
80    # this method returns true iff all composite arguments
81    # in this sum expression are constant, i.e., numeric
82    # constants or parametrs. the parameter values can of
83    # course change over time, but at any point in time,
84    # they are constant. hence, the name.
85    #
86    def is_constant(self):
87        for arg in self._args:
88            if not arg.is_constant():
89                return False
90        return True
91
92    def simplify(self, model):
93        """ Walk through the S-expression, performing simplifications, and
94            replacing values of parameters with constants.
95        """
96        for i in range(0,len(self._args)):
97          if isinstance(self._args[i],Expression):
98             self._args[i] = self._args[i].simplify(model)
99        return self
100
101    def __call__(self, exception=True):
102        """Evaluate the expression"""
103        values=[]
104        for arg in self._args:
105          try:
106            val = value(arg)
107          except ValueError, e:
108            if exception:
109                raise ValueError, "Error evaluating expression: %s" % str(e)
110            return None
111          #try:
112          #  val = value(arg)
113          #except AttributeError:
114          #  return None
115          if val is None:
116             return None
117          values.append( val )
118        return self._apply_operation(values)
119
120    def _apply_operation(self, values):
121        """Method that can be overwritten to re-define the operation in this expression"""
122        if self._tuple_op:
123           tmp=tuple(values)
124           return self._operation(*tmp)
125        else:
126           return self._operation(values)
127       
128    def __str__(self):
129        return self.name
130
131
132class _LessThanExpression(Expression):
133    """An object that defines a less-than expression"""
134
135    def __init__(self, args=()):
136        """Constructor"""
137        Expression.__init__(self,args=args,nargs=2,name='lt')
138
139    def _apply_operation(self, values):
140        """Method that defines the less-than operation"""
141        return values[0] < values[1]
142
143
144class _GreaterThanExpression(Expression):
145    """An object that defines a greater-than expression"""
146
147    def __init__(self, args=()):
148        """Constructor"""
149        Expression.__init__(self,args=args,nargs=2,name='gt')
150
151    def _apply_operation(self, values):
152        """Method that defines the greater-than operation"""
153        return values[0] > values[1]
154
155
156class _LessThanOrEqualExpression(Expression):
157    """An object that defines a less-than-or-equal expression"""
158
159    def __init__(self, args=()):
160        """Constructor"""
161        Expression.__init__(self,args=args,nargs=2,name='lt')
162
163    def _apply_operation(self, values):
164        """Method that defines the less-than-or-equal operation"""
165        return values[0] <= values[1]
166
167
168class _GreaterThanOrEqualExpression(Expression):
169    """An object that defines a greater-than-or-equal expression"""
170
171    def __init__(self, args=()):
172        """Constructor"""
173        Expression.__init__(self,args=args,nargs=2,name='gt')
174
175    def _apply_operation(self, values):
176        """Method that defines the greater-than-or-equal operation"""
177        return values[0] >= values[1]
178
179
180class _EqualToExpression(Expression):
181    """An object that defines a equal-to expression"""
182
183    def __init__(self, args=()):
184        """Constructor"""
185        Expression.__init__(self,args=args,nargs=2,name='eq')
186
187    def _apply_operation(self, values):
188        """Method that defines the equal-to operation"""
189        return values[0] == values[1]
190
191
192class _ProductExpression(Expression):
193    """An object that defines a product expression"""
194
195    def __init__(self, args=()):
196        """Constructor"""
197        self._denominator = []
198        self._numerator = list(args)
199        self.coef = 1
200        Expression.__init__(self,nargs=-1,name='prod')
201
202    def is_constant(self):
203        for arg in self._numerator:
204            if not arg.is_constant():
205                return False
206        for arg in self._denominator:
207            if not arg.is_constant():
208                return False
209        return True
210
211    def invert(self):
212        tmp = self._denominator
213        self._denominator = self._numerator
214        self._numerator = tmp
215        self.coef = 1.0/self.coef
216
217    def simplify(self, model):
218        #
219        # First, we apply the standard simplification of arguments
220        #
221        for i in range(0,len(self._numerator)):
222            if isinstance(self._numerator[i],Expression):
223                self._numerator[i] = self._numerator[i].simplify(model)
224        for i in range(0,len(self._denominator)):
225            if isinstance(self._denominator[i],Expression):
226                self._denominator[i] = self._denominator[i].simplify(model)
227        #
228        # Next, we collapse nested products
229        #
230        tmpnum = []
231        tmpdenom = []
232        tmpcoef = self.coef
233        for arg in self._numerator:
234            if isinstance(arg,_ProductExpression):
235                tmpnum = tmpnum + arg._numerator
236                tmpdenom = tmpdenom + arg._denominator
237                tmpcoef *= arg.coef
238            else:
239                tmpnum.append( arg )
240        for arg in self._denominator:
241            if isinstance(arg,_ProductExpression):
242                tmpnum = tmpnum + arg._denominator
243                tmpdenom = tmpdenom + arg._numerator
244                if arg.coef != 1:
245                    tmpcoef /= 1.0*arg.coef
246            else:
247                tmpdenom.append( arg )
248        if tmpcoef == 0:
249           return ZeroConstant
250        #
251        # Next, we eliminate constants
252        #
253        newnum = []
254        newdenom = []
255        newcoef = tmpcoef
256        for arg in tmpnum:
257            if type(arg) is NumericConstant:
258                if arg.value != 1:
259                    newcoef *= arg.value
260            else:
261                newnum.append( arg )
262        for arg in tmpdenom:
263            if type(arg) is NumericConstant:
264                if arg.value != 1:
265                    newcoef /= 1.0*arg.value
266            else:
267                newdenom.append( arg )
268        if newcoef == 0:
269            return ZeroConstant
270        #
271        # Return simplified expression
272        #
273        nargs = len(newnum)+len(newdenom)
274        if nargs == 1 and len(newnum) == 1:
275            if newcoef == 1:
276                return newnum[0]
277            if type(newnum[0]) is _SumExpression:
278                newnum[0].scale(newcoef)
279                return newnum[0]
280        if nargs == 0 and newcoef == 1:
281           return OneConstant
282        self._numerator = newnum
283        self._denominator = newdenom
284        self.coef = newcoef
285        return self
286
287    def add(self,numerator=None,denominator=None):
288       # print "TYPE",type(numerator),type(denominator)
289        if not numerator is None:
290            self._numerator.append(numerator)
291            self._nargs += 1
292        if not denominator is None:
293            self._denominator.append(denominator)
294            self._nargs += 1
295
296    def pprint(self, ostream=None, nested=True, eol_flag=True):
297        """Print this expression"""
298        if ostream is None:
299           ostream = sys.stdout
300        if nested:
301           print >>ostream, self.name + "( num=(",
302           first=True
303           if self.coef != 1:
304                print >>ostream, str(self.coef),
305                first=False
306           for arg in self._numerator:
307             if first==False:
308                print >>ostream, ",",
309             if isinstance(arg,Expression):
310                arg.pprint(ostream=ostream, nested=nested, eol_flag=False)
311             else:
312                print >>ostream, str(arg),
313             first=False
314           if first is True:
315              print >>ostream, 1,
316           print >>ostream, ")",
317           if len(self._denominator) > 0:
318              print >>ostream, ", denom=(",
319              first=True
320              for arg in self._denominator:
321                if first==False:
322                   print >>ostream, ",",
323                if isinstance(arg,Expression):
324                   arg.pprint(ostream=ostream, nested=nested, eol_flag=False)
325                else:
326                   print >>ostream, str(arg),
327                first=False
328              print >>ostream, ")",
329           print >>ostream, ")",
330           if eol_flag==True:
331              print >>ostream, ""
332
333    def clone(self, args=()):
334        """Clone this object using the specified arguments"""
335        tmp = self.__class__()
336        tmp.name = self.name
337        tmp.coef = self.coef
338        tmp._numerator = copy.copy(self._numerator)
339        tmp._denominator = copy.copy(self._denominator)
340        return tmp
341
342    def verify(self):
343        if self._denominator is None or self._numerator is None:
344           return
345        for arg in self._numerator:
346            if type(arg) is float:      #pragma:nocover
347                raise ValueError, "Argument for expression "+self.name+" is a float!"
348            if (isinstance(arg,_ParamBase) or isinstance(arg,_VarBase)) and arg.dim() > 0:
349                raise ValueError, "Argument for expression "+self.name+" is an n-ary numeric value: "+arg.name
350        for arg in self._denominator:
351            if type(arg) is float:      #pragma:nocover
352                raise ValueError, "Argument for expression "+self.name+" is a float!"
353            if (isinstance(arg,_ParamBase) or isinstance(arg,_VarBase)) and arg.dim() > 0:
354                raise ValueError, "Argument for expression "+self.name+" is an n-ary numeric value: "+arg.name
355
356    def __call__(self, exception=True):
357        """Evaluate the expression"""
358        ans = self.coef
359        for arg in self._numerator:
360            try:
361                val = value(arg)
362            except ValueError, e:
363                if exception:
364                    raise ValueError, "Error evaluating expression: %s" % str(e)
365                return None
366            if val is None:
367                return None
368            ans *= val
369        for arg in self._denominator:
370            try:
371                val = value(arg)
372            except ValueError, e:
373                if exception:
374                    raise ValueError, "Error evaluating expression: %s" % str(e)
375                return None
376            if val is None:
377                return None
378            if val != 1:
379                ans /= 1.0*val
380        return ans
381
382
383class _IdentityExpression(Expression):
384    """An object that defines a identity expression"""
385
386    def __init__(self, args=()):
387        """Constructor"""
388        if isinstance(args,list):
389           Expression.__init__(self,args=args,nargs=1,name='identity')
390        else:
391           Expression.__init__(self,args=[args],nargs=1,name='identity')
392
393    def _apply_operation(self, values):
394        """Method that defines the identity operation"""
395        return values[0]
396
397
398#class X_NegateExpression(Expression):
399    #"""An object that defines a negation expression"""
400#
401    #def __init__(self, args=()):
402        #"""Constructor"""
403        #Expression.__init__(self,args=args,nargs=1,name='negate')
404#
405    #def _apply_operation(self, values):
406        #"""Method that defines the negation operation"""
407        #return -values[0]
408
409
410class _AbsExpression(Expression):
411
412    def __init__(self, args=()):
413        Expression.__init__(self, args=args, nargs=1, name='abs', operation=abs, tuple_op=True)
414
415
416class _PowExpression(Expression):
417
418    def __init__(self, args=()):
419        Expression.__init__(self, args=args, nargs=2, name='pow', operation=pow, tuple_op=True)
420
421
422class _SumExpression(Expression):
423    """An object that defines a weighted summation of expressions"""
424
425    def __init__(self, args=(), coef=()):
426        """Constructor"""
427        Expression.__init__(self,args=list(args),nargs=-1,name='sum')
428        self._coef = list(coef)
429        self._const = 0
430
431    def clone(self, args=()):
432        """Clone this object using the specified arguments"""
433        tmp = self.__class__()
434        tmp.name = self.name
435        tmp._args = copy.copy(self._args)
436        tmp._coef = copy.copy(self._coef)
437        tmp._const = self._const
438        return tmp
439
440    def scale(self, val):
441        for i in range(len(self._coef)):
442            self._coef[i] *= val
443        self._const *= val
444
445    def negate(self):
446        self.scale(-1)
447
448    def simplify(self, model):
449        #
450        # First, we apply the standard simplification of arguments
451        #
452        Expression.simplify(self,model)
453        #
454        # Next, we collapse nested sums
455        #
456        tmpargs = []
457        tmpcoef = []
458        tmpconst = self._const
459        for i in range(len(self._args)):
460          arg = self._args[i]
461          if isinstance(arg,_SumExpression):
462             tmpargs = tmpargs + arg._args
463             tmpcoef = tmpcoef + map(lambda x:x*self._coef[i], arg._coef)
464             tmpconst += arg._const*self._coef[i]
465          else:
466             tmpargs.append( arg )
467             tmpcoef.append( self._coef[i] )
468        #
469        # Next, we simplify arguments
470        #
471        newargs = []
472        newcoef = []
473        newconst = tmpconst
474        for i in range(len(tmpargs)):
475            arg = tmpargs[i]
476            if type(arg) is NumericConstant:
477                if arg.value != 0:
478                    newconst += arg.value*tmpcoef[i]
479            elif isinstance(arg,_ProductExpression):
480                newcoef.append( arg.coef*tmpcoef[i] )
481                arg.coef = 1
482                newargs.append( arg )
483            else:
484                newargs.append( arg )
485                newcoef.append( tmpcoef[i] )
486        #
487        # Return simplified expression
488        #
489        if len(newargs) == 1 and newcoef[0] == 1 and newconst == 0:
490           return newargs[0]
491        elif len(newargs) > 0 or newconst != 0:
492            self._args = newargs
493            self._coef = newcoef
494            self._const = newconst
495            return self
496        else:
497           return ZeroConstant
498
499    def pprint(self, ostream=None, nested=True, eol_flag=True):
500        """Print this expression"""
501        if ostream is None:
502           ostream = sys.stdout
503        if nested:
504           print >>ostream, self.name + "(",
505           first=True
506           if self._const != 0:
507                print >>ostream, str(self._const),
508                first=False
509           for i in range(len(self._args)):
510             arg = self._args[i]
511             if first==False:
512                print >>ostream, ",",
513             if self._coef[i] != 1:
514                print >>ostream, str(self._coef[i])+" * ",
515             if isinstance(arg,Expression):
516                arg.pprint(ostream=ostream, nested=nested, eol_flag=False)
517             else:
518                print >>ostream, str(arg),
519             first=False
520           print >>ostream, ")",
521           if eol_flag==True:
522              print >>ostream, ""
523
524    def __call__(self, exception=True):
525        """Evaluate the expression"""
526        values=[]
527        for i in range(len(self._args)):
528            arg = self._args[i]
529            try:
530                val = value(arg)
531            except ValueError, e:
532                if exception:
533                    raise ValueError, "Error evaluating expression: %s" % str(e)
534                return None
535            if val is None:
536                return None
537            values.append( self._coef[i]*val )
538        return sum(values)+self._const
539
540    def add(self,coef=1,expr=None):
541        self._args.append(expr)
542        self._coef.append(coef)
543        self._nargs += 1
544
545
546def generate_expression(*_args):
547    #print "HERE",_args
548    etype = _args[0]
549
550    if etype is 'neg':
551        if type(_args[1]) is _SumExpression:
552            _args[1].negate()
553            return _args[1]
554        else:
555            etype = 'mul'
556            args=[_args[1],-1]
557    else:
558        args=_args[1:]
559
560    if etype is 'add':
561        #
562        # self + other
563        #
564        other = as_numeric(args[1])
565        if type(args[0]) is _SumExpression:
566            args[0].add(expr=other)
567            return args[0]
568        else:
569            tmp = _SumExpression()
570            tmp.add(expr=args[0])
571            tmp.add(expr=other)
572            return tmp
573
574    elif etype is 'radd':
575        #
576        # other + self
577        #
578        other = as_numeric(args[1])
579        if type(args[0]) is _SumExpression:
580            args[0].add(expr=other)
581            return args[0]
582        else:
583            tmp = _SumExpression()
584            tmp.add(expr=other)
585            tmp.add(expr=args[0])
586            return tmp
587
588    elif etype is 'sub':
589        #
590        # self - other
591        #
592        other = as_numeric(args[1])
593        if type(args[0]) is _SumExpression:
594            args[0].add(coef=-1, expr=other)
595            return args[0]
596        else:
597            tmp = _SumExpression()
598            tmp.add(expr=args[0])
599            tmp.add(coef=-1, expr=other)
600            return tmp
601
602    elif etype is 'rsub':
603        #
604        # other - self
605        #
606        other = as_numeric(args[1])
607        if type(args[0]) is _SumExpression:
608            args[0].negate()
609            args[0].add(expr=other)
610            return args[0]
611        else:
612            tmp = _SumExpression()
613            tmp.add(expr=other)
614            tmp.add(coef=-1, expr=args[0])
615            return tmp
616
617    elif etype is 'mul':
618        #
619        # self * other
620        #
621        other = as_numeric(args[1])
622        if type(args[0]) is _ProductExpression:
623            args[0].add(numerator=other)
624            return args[0]
625        else:
626            tmp = _ProductExpression()
627            tmp.add(numerator=args[0])
628            tmp.add(numerator=other)
629            return tmp
630
631    elif etype is 'rmul':
632        #
633        # other * self
634        #
635        other = as_numeric(args[1])
636        if type(args[0]) is _ProductExpression:
637            args[0].add(numerator=other)
638            return args[0]
639        else:
640            tmp = _ProductExpression()
641            tmp.add(numerator=other)
642            tmp.add(numerator=args[0])
643            return tmp
644
645    elif etype is 'div':
646        #
647        # self / other
648        #
649        other = as_numeric(args[1])
650        if type(args[0]) is _ProductExpression:
651            args[0].add(denominator=other)
652            return args[0]
653        else:
654            tmp = _ProductExpression()
655            tmp.add(numerator=args[0])
656            tmp.add(denominator=other)
657            #print "X",tmp.pprint()
658            return tmp
659
660    elif etype is 'rdiv':
661        #
662        # other / self
663        #
664        other = as_numeric(args[1])
665        if type(args[0]) is _ProductExpression:
666            args[0].invert()
667            args[0].add(numerator=other)
668            return args[0]
669        else:
670            tmp = _ProductExpression()
671            tmp.add(numerator=other)
672            tmp.add(denominator=args[0])
673            return tmp
674
675    elif etype is 'pow':
676        #
677        # self ** other
678        #
679        return _PowExpression([args[0], as_numeric(args[1])])
680
681    elif etype is 'rpow':
682        #
683        # other ** self
684        #
685        return _PowExpression([as_numeric(args[1]), args[0]])
686
687    elif etype is 'abs':
688        return _AbsExpression([args[0]])
689
690
691ExpressionRegistration('<', _LessThanExpression)
692ExpressionRegistration('lt', _LessThanExpression)
693ExpressionRegistration('>', _GreaterThanExpression)
694ExpressionRegistration('gt', _GreaterThanExpression)
695ExpressionRegistration('<=', _LessThanOrEqualExpression)
696ExpressionRegistration('lte', _LessThanOrEqualExpression)
697ExpressionRegistration('>=', _GreaterThanOrEqualExpression)
698ExpressionRegistration('gte', _GreaterThanOrEqualExpression)
699ExpressionRegistration('=', _EqualToExpression)
700ExpressionRegistration('eq', _EqualToExpression)
701
702if False:
703    ExpressionRegistration('+', _SumExpression)
704    ExpressionRegistration('sum', _SumExpression)
705    ExpressionRegistration('*', _ProductExpression)
706    ExpressionRegistration('prod', _ProductExpression)
707    #ExpressionRegistration('-', _MinusExpression)
708    #ExpressionRegistration('minus', _MinusExpression)
709    #ExpressionRegistration('/', _DivisionExpression)
710    #ExpressionRegistration('divide', _DivisionExpression)
711    #ExpressionRegistration('-', _NegateExpression)
712    #ExpressionRegistration('negate', _NegateExpression)
713    ExpressionRegistration('abs', _AbsExpression)
714    ExpressionRegistration('pow', _PowExpression)
715
Note: See TracBrowser for help on using the repository browser.