source: trunk/test/core/colin/test_parallel.py @ 1794

Last change on this file since 1794 was 1794, checked in by wehart, 10 years ago

Miscellaneous fixes for bugs that were introduced along with the variable_map
data, which is now called symbol_map.

Note: some tests still fail because pico_convert does not generate
symbol-mapping information. This is being resolved.

File size: 12.8 KB
Line 
1#
2# Unit Tests for coopr.opt.parallel (using the COLIN optimizers)
3#
4#
5
6import os
7import sys
8from os.path import abspath, dirname
# Directory holding this test file; baselines and scratch files live here.
currdir = dirname(abspath(__file__)) + os.sep
# Root of the coopr tree: three directory levels above the test directory.
# It is inserted at the front of sys.path so it is searched before any other
# entry when importing coopr.
cooprdir = dirname(dirname(dirname(dirname(abspath(__file__)))))
sys.path.insert(0, cooprdir)
cooprdir = cooprdir + os.sep
13
14import unittest
15from nose.tools import nottest
16import xml
17import coopr.opt
18from coopr.opt import ResultsFormat, ProblemFormat
19import pyutilib.th
20import pyutilib.plugin.core
21import pyutilib.services
22
23
class TestProblem1(coopr.opt.colin.MixedIntOptProblem):
    """
    Four-variable real-valued test problem used by the PatternSearch tests.

    Objective:  x0 - x1 + (x2 - 1.5)**2 + (x3 + 2)**4

    Lower bounds [0.0, -1.0, 1.0, None], upper bounds [None, 0.0, 2.0, -1.0];
    a None entry presumably means "no bound on that side" -- TODO confirm
    against MixedIntOptProblem.
    """

    def __init__(self):
        coopr.opt.colin.MixedIntOptProblem.__init__(self)
        # Per-variable bounds, indexed like point.reals.
        self.real_lower = [0.0, -1.0, 1.0, None]
        self.real_upper = [None, 0.0, 2.0, -1.0]
        self.nreal = 4

    def function_value(self, point):
        """Validate *point*, then return the objective evaluated at it."""
        self.validate(point)
        x = point.reals
        linear = x[0] - x[1]
        return linear + (x[2] - 1.5)**2 + (x[3] + 2)**4
35
36
class TestSolverManager(coopr.opt.parallel.AsynchronousSolverManager):
    """
    Mock asynchronous solver manager used by the factory and registration
    tests.  It always reports itself as not enabled.
    """

    def __init__(self, **kwds):
        # The type and doc strings are fixed; any caller-supplied values for
        # these two keywords are overwritten.
        kwds.update(type='smtest_type', doc='TestASM Documentation')
        coopr.opt.parallel.AsynchronousSolverManager.__init__(self, **kwds)

    def enabled(self):
        """Always report this manager as disabled."""
        return False
46
47
class SolverManager_DelayedSerial(coopr.opt.parallel.AsynchronousSolverManager):
    """
    Mock asynchronous solver manager for exercising the wait*() methods.

    Solves run synchronously when they are queued.  Their results are only
    released by _perform_wait_any() on every ``delay``-th poll, one handle at
    a time, most recently queued first.  ``_force_error`` selects deliberate
    misbehavior of _perform_wait_any():

        0 -- normal delayed behavior
        1 -- return an ActionHandle flagged as an error
        2 -- return a queued ActionHandle without publishing its results
    """

    def clear(self):
        """
        Clear manager state
        """
        coopr.opt.parallel.AsynchronousSolverManager.clear(self)
        # Number of polls per released result (see _perform_wait_any).
        self.delay = 5
        # Handles whose solves have run but have not been reported yet.
        self._ah_list = []
        # Solver used for queued solves; set from the 'opt' keyword.
        self._opt = None
        # Results keyed by ActionHandle id, held until the handle is reported.
        self._my_results = {}
        # Poll counter that drives the artificial delay.
        self._ctr = 1
        # Failure mode for _perform_wait_any (see class docstring).
        self._force_error = 0

    def _perform_queue(self, ah, *args, **kwds):
        """
        Perform the queue operation.  This method returns the ActionHandle,
        and the ActionHandle status indicates whether the queue was successful.

        The solver comes from the 'opt' keyword and is remembered for later
        calls.  The remaining keywords (e.g. 'logfile') are forwarded to the
        solver's solve() method.

        Raises:
            coopr.opt.parallel.manager.ActionManagerError: if no solver has
                been supplied yet.
        """
        if 'opt' in kwds:
            self._opt = kwds.pop('opt')
        if self._opt is None:
            # ActionManagerError is not imported into this module, so it must
            # be qualified; the bare name would raise NameError instead.
            raise coopr.opt.parallel.manager.ActionManagerError("Undefined solver")
        self._my_results[ah.id] = self._opt.solve(*args, **kwds)
        self._ah_list.append(ah)
        return ah

    def _perform_wait_any(self):
        """
        Perform the wait_any operation.  This method returns an
        ActionHandle with the results of waiting.  If None is returned
        then the ActionManager assumes that it can call this method again.
        Note that an ActionHandle can be returned with a dummy value,
        to indicate an error.

        Raises:
            ValueError: if ``_force_error`` is not 0, 1 or 2 (returning None
                here would make the manager poll forever).
        """
        if self._force_error == 0:
            # Only every ``delay``-th poll makes progress; None asks the
            # ActionManager to poll again.
            self._ctr += 1
            if self._ctr % self.delay != 0:
                return None
            if self._ah_list:
                # pop() releases the most recently queued handle first;
                # test_delayed_serial5 depends on this ordering.
                ah = self._ah_list.pop()
                ah.status = coopr.opt.parallel.manager.ActionStatus.done
                self.results[ah.id] = self._my_results[ah.id]
                return ah
            return coopr.opt.parallel.manager.ActionHandle(error=True, explanation="No queued evaluations available in the 'local' solver manager, which only executes solvers synchronously")
        elif self._force_error == 1:
            #
            # Wait Any returns an ActionHandle that indicates an error
            #
            return coopr.opt.parallel.manager.ActionHandle(error=True, explanation="Forced failure")
        elif self._force_error == 2:
            #
            # Wait Any returns the correct ActionHandle, but no results are
            # available.
            #
            return self._ah_list.pop()
        raise ValueError("Unknown _force_error value: %r" % (self._force_error,))
105
106
class Test(pyutilib.th.TestCase):
    """
    Unit tests for the coopr.opt.parallel solver managers, driven by the
    COLIN PatternSearch optimizer on TestProblem1.
    """

    def run(self, result=None):
        """
        Run one test while the mock "smtest" solver manager is registered.

        Deactivation happens in a ``finally`` clause, so an exception that
        escapes TestCase.run (e.g. KeyboardInterrupt) cannot leave the
        plugin registered.
        """
        self.smtest_plugin = coopr.opt.SolverManagerRegistration("smtest", TestSolverManager)
        try:
            unittest.TestCase.run(self, result)
        finally:
            self.smtest_plugin.deactivate()

    def setUp(self):
        self.do_setup(False)

    def do_setup(self, flag):
        """
        Send temporary files to the test directory and create a fresh
        PatternSearch optimizer in self.ps.  *flag* is accepted but unused.
        """
        pyutilib.services.TempfileManager.tempdir = currdir
        self.ps = coopr.opt.colin.PatternSearch()

    def tearDown(self):
        pyutilib.services.TempfileManager.clear_tempfiles()

    def _setup_problem(self):
        """Load a fresh TestProblem1 and its starting point into self.ps."""
        self.ps.problem = TestProblem1()
        self.ps.initial_point = [1.0, -0.5, 2.0, -1.0]
        self.ps.reset()

    def _remove_files(self, *names):
        """Delete each named scratch file in currdir, if it exists."""
        for name in names:
            path = currdir + name
            if os.path.exists(path):
                os.remove(path)

    def test_solve1(self):
        """ Test PatternSearch - TestProblem1 """
        self._setup_problem()
        results = self.ps.solve(logfile=currdir+"test_solve1.log")
        results.write(currdir+"test_solve1.txt", times=False)
        self.failUnlessFileEqualsBaseline(currdir+"test_solve1.txt", currdir+"test1_ps.txt")
        self._remove_files("test_solve1.log")

    def test_serial1(self):
        """ Test Serial SolverManager - TestProblem1 """
        self._setup_problem()
        mngr = coopr.opt.parallel.SolverManagerFactory("serial")
        results = mngr.solve(opt=self.ps, logfile=currdir+"test_solve2.log")
        results.write(currdir+"test_solve2.txt", times=False)
        self.failUnlessFileEqualsBaseline(currdir+"test_solve2.txt", currdir+"test1_ps.txt")
        self._remove_files("test_solve2.log")

    def test_serial_error1(self):
        """ Test Serial SolverManager - Error with no optimizer"""
        self._setup_problem()
        mngr = coopr.opt.parallel.SolverManagerFactory("serial")
        try:
            mngr.solve(logfile=currdir+"test_solve3.log")
            self.fail("Expected error")
        except coopr.opt.parallel.manager.ActionManagerError:
            pass
        self._remove_files("test_solve3.log")

    def test_serial_error2(self):
        """ Test Serial SolverManager - Error with no queued solves"""
        self._setup_problem()
        mngr = coopr.opt.parallel.SolverManagerFactory("serial")
        mngr.solve(opt=self.ps, logfile=currdir+"test_solve3.log")
        # Expected: nothing remains queued after solve() returns, so wait_any()
        # reports the failed-action sentinel.
        if mngr.wait_any() != coopr.opt.parallel.manager.FailedActionHandle:
            self.fail("Expected a failed action")
        # Remove the log this test actually wrote.
        self._remove_files("test_solve3.log")

    def test_solver_manager_factory(self):
        """
        Testing the coopr.opt solver manager factory
        """
        ans = coopr.opt.SolverManagerFactory()
        self.failUnless(set(["smtest"]) <= set(ans))

    def test_solver_manager_instance(self):
        """
        Testing that we get a specific solver manager instance
        """
        ans = coopr.opt.SolverManagerFactory("none")
        self.failUnlessEqual(ans, None)
        ans = coopr.opt.SolverManagerFactory("smtest")
        self.failUnlessEqual(type(ans), TestSolverManager)
        ans = coopr.opt.SolverManagerFactory("smtest", "mymock")
        self.failUnlessEqual(type(ans), TestSolverManager)
        self.failUnlessEqual(ans.name, "mymock")

    def test_solver_manager_registration(self):
        """
        Testing methods in the solver manager factory registration process
        """
        ep = pyutilib.plugin.core.ExtensionPoint(coopr.opt.parallel.solver.ISolverManagerRegistration)
        service = ep.service("smtest")
        self.failUnlessEqual(service.type(), "smtest")

    def test_delayed_serial1(self):
        """
        Use a solver manager that delays the evaluation of responses,
        and thus allows a mock testing of the wait*() methods.
        """
        self._setup_problem()
        mngr = SolverManager_DelayedSerial()
        results = mngr.solve(opt=self.ps, logfile=currdir+"test_solve4.log")
        results.write(currdir+"test_solve4.txt", times=False)
        self.failUnlessFileEqualsBaseline(currdir+"test_solve4.txt", currdir+"test1_ps.txt")
        self._remove_files("test_solve4.log")

    def test_delayed_serial2(self):
        """
        Use a solver manager that delays the evaluation of responses,
        and _perform_wait_any() returns a failed action handle.
        """
        self._setup_problem()
        mngr = SolverManager_DelayedSerial()
        mngr._force_error = 1
        try:
            mngr.solve(opt=self.ps, logfile=currdir+"test_solve5.log")
            self.fail("Expected error")
        except coopr.opt.parallel.manager.ActionManagerError:
            pass
        self._remove_files("test_solve5.log")

    def test_delayed_serial3(self):
        """
        Use a solver manager that delays the evaluation of responses,
        and _perform_wait_any() returns an action handle, but no results are available.
        """
        self._setup_problem()
        mngr = SolverManager_DelayedSerial()
        mngr._force_error = 2
        try:
            mngr.solve(opt=self.ps, logfile=currdir+"test_solve6.log")
            self.fail("Expected error")
        except coopr.opt.parallel.manager.ActionManagerError:
            pass
        self._remove_files("test_solve6.log")

    def test_delayed_serial4(self):
        """
        Use a solver manager that delays the evaluation of responses,
        and verify that wait_all() completes several queued solves.
        """
        self._setup_problem()
        mngr = SolverManager_DelayedSerial()
        mngr.queue(opt=self.ps, logfile=currdir+"test_solve7a.log")
        mngr.queue(opt=self.ps, logfile=currdir+"test_solve7b.log")
        ah_c = mngr.queue(opt=self.ps, logfile=currdir+"test_solve7c.log")

        mngr.wait_all()

        self.failUnlessEqual(ah_c.status, coopr.opt.parallel.manager.ActionStatus.done)
        self._remove_files("test_solve7a.log", "test_solve7b.log", "test_solve7c.log")

    def test_delayed_serial5(self):
        """
        Use a solver manager that delays the evaluation of responses,
        and verify that waiting on one handle leaves older handles queued.
        """
        self._setup_problem()
        mngr = SolverManager_DelayedSerial()
        ah_a = mngr.queue(opt=self.ps, logfile=currdir+"test_solve8a.log")
        ah_b = mngr.queue(opt=self.ps, logfile=currdir+"test_solve8b.log")
        mngr.queue(opt=self.ps, logfile=currdir+"test_solve8c.log")

        # The delayed manager releases handles last-in-first-out, so ah_c and
        # ah_b complete while ah_a is still queued.
        mngr.wait_all(ah_b)

        self.failUnlessEqual(ah_b.status, coopr.opt.parallel.manager.ActionStatus.done)
        self.failUnlessEqual(ah_a.status, coopr.opt.parallel.manager.ActionStatus.queued)
        self._remove_files("test_solve8a.log", "test_solve8b.log", "test_solve8c.log")

    def test_delayed_serial6(self):
        """
        Use a solver manager that delays the evaluation of responses,
        and verify num_queued(), get_status() and wait_all() with a list.
        """
        self._setup_problem()
        mngr = SolverManager_DelayedSerial()
        ah_a = mngr.queue(opt=self.ps, logfile=currdir+"test_solve9a.log")
        ah_b = mngr.queue(opt=self.ps, logfile=currdir+"test_solve9b.log")
        mngr.queue(opt=self.ps, logfile=currdir+"test_solve9c.log")

        self.failUnlessEqual(mngr.num_queued(), 3)
        mngr.wait_all([ah_b])

        self.failUnlessEqual(mngr.get_status(ah_b), coopr.opt.parallel.manager.ActionStatus.done)
        self.failUnlessEqual(mngr.get_status(ah_a), coopr.opt.parallel.manager.ActionStatus.queued)

        self._remove_files("test_solve9a.log", "test_solve9b.log", "test_solve9c.log")
337
if "__main__" == __name__:
    # Allow running this test module directly as a script.
    unittest.main()
340
Note: See TracBrowser for help on using the repository browser.