source: trunk/Cbc/examples/sudoku.cpp

Last change on this file was 1898, checked in by stefan, 5 years ago

fixup examples

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 14.4 KB
Line 
1// $Id: sudoku.cpp 1898 2013-04-09 18:06:04Z forrest $
2// Copyright (C) 2005, International Business Machines
3// Corporation and others.  All Rights Reserved.
4// This code is licensed under the terms of the Eclipse Public License (EPL).
5
6#include <cassert>
7#include <iomanip>
8
9
10#include "CoinPragma.hpp"
11// For Branch and bound
12#include "OsiSolverInterface.hpp"
13#include "CbcModel.hpp"
14#include "CoinModel.hpp"
15// For all different
16#include "CbcBranchCut.hpp"
17#include "CbcBranchActual.hpp"
18#include "CbcBranchAllDifferent.hpp"
19#include "CbcCutGenerator.hpp"
20#include "CglAllDifferent.hpp"
21#include "OsiClpSolverInterface.hpp"
22#include "CglStored.hpp"
23
24#include "CoinTime.hpp"
25
26
27/************************************************************************
28
29This shows how we can define a new branching method to solve problems with
30all different constraints.
31
32We are going to solve a sudoku problem such as
33
341, , ,4, , ,7, ,
35 ,2, , ,5, , ,8,
368,7,3, , ,6, , ,9
374, , ,7, , ,1, ,
38 ,5, , ,8, , ,2,
39 , ,6, ,4,9, , ,3
407, , ,1, , ,4, ,
41 ,8, ,6,2, , ,5,
42 , ,9, ,7,3, ,1,6
43
44The input should be exported from spreadsheet as a csv file where cells are
45empty unless value is known.
46
47We set up a fake objective and simple constraints to say sum must be 45
48
49and then we add all different branching (CbcBranchAllDifferent)
50and all different cuts (to fix variables) (CglAllDifferent).
51
52CbcBranchAllDifferent is really just an example of a cut branch.  If we wish to stop x==y
53then we can have two branches - one x <= y-1 and the other x >= y+1.  It should be easy for the user to
54make up similar cut branches for other uses.
55
56Note - this is all we need to solve most 9 x 9 puzzles because they seem to solve
57at root node or easily.  To solve 16 x 16 puzzles we need more.  All different cuts
58need general integer variables to be fixed while we can branch so they are just at bounds. 
59To get round that we can introduce extra 0-1 variables such that general integer x = sum j * delta j
60and then do N way branching on these (CbcNWay) so that we fix one delta j to 1.  At the same time we use the
61new class CbcConsequence (used in N way branching) which when delta j goes to 1 fixes other variables.
62So it will fix x to the correct value and while we are at it we can fix some delta variables in other
63sets to zero (as per all different rules).  Finally as well as storing the instructions which say if
64delta 11 is 1 then delta 21 is 0 we can also add this in as a cut using new trivial cut class
65CglStored.
66
67************************************************************************/
68
69int main (int argc, const char *argv[])
70{
71  // Get data
72  std::string fileName = "./sudoku_sample.csv";
73  if (argc>=2) fileName = argv[1];
74  FILE * fp = fopen(fileName.c_str(),"r");
75  if (!fp) {
76    printf("Unable to open file %s\n",fileName.c_str());
77    exit(0);
78  }
79#define MAX_SIZE 16
80  int valueOffset=1;
81  double lo[MAX_SIZE*MAX_SIZE],up[MAX_SIZE*MAX_SIZE];
82  char line[80];
83  int row,column;
84  /***************************************
85     Read .csv file and see if 9 or 16 Su Doku
86  ***************************************/
87  int size=9;
88  for (row=0;row<size;row++) {
89    fgets(line,80,fp);
90    // Get size of sudoku puzzle (9 or 16)
91    if (!row) {
92      int get=0;
93      size=1;
94      while (line[get]>=32) {
95        if (line[get]==',')
96          size++;
97        get++;
98      }
99      assert (size==9||size==16);
100      printf("Solving Su Doku of size %d\n",size);
101      if (size==16)
102        valueOffset=0;
103    }
104    int get=0;
105    for (column=0;column<size;column++) {
106      lo[size*row+column]=valueOffset;
107      up[size*row+column]=valueOffset-1+size;
108      if (line[get]!=','&&line[get]>=32) {
109        // skip blanks
110        if (line[get]==' ') {
111          get++;
112          continue;
113        }
114        int value = line[get]-'0';
115        if (size==9) {
116          assert (value>=1&&value<=9);
117        } else {
118          assert (size==16);
119          if (value<0||value>9) {
120            if (line[get]=='"') {
121              get++;
122              value = 10 + line[get]-'A';
123              if (value<10||value>15) {
124                value = 10 + line[get]-'a';
125              }
126              get++;
127            } else {
128              value = 10 + line[get]-'A';
129              if (value<10||value>15) {
130                value = 10 + line[get]-'a';
131              }
132          }
133          }
134          assert (value>=0&&value<=15);
135        }
136        lo[size*row+column]=value;
137        up[size*row+column]=value;
138        get++;
139      }
140      get++;
141    }
142  }
143  int block_size = (int) sqrt ((double) size);
144  /***************************************
145     Now build rules for all different
146     3*9 or 3*16 sets of variables
147     Number variables by row*size+column
148  ***************************************/
149  int starts[3*MAX_SIZE+1];
150  int which[3*MAX_SIZE*MAX_SIZE];
151  int put=0;
152  int set=0;
153  starts[0]=0;
154  // By row
155  for (row=0;row<size;row++) {
156    for (column=0;column<size;column++) 
157      which[put++]=row*size+column;
158    starts[set+1]=put;
159    set++;
160  }
161  // By column
162  for (column=0;column<size;column++) {
163    for (row=0;row<size;row++) 
164      which[put++]=row*size+column;
165    starts[set+1]=put;
166    set++;
167  }
168  // By box
169  for (row=0;row<size;row+=block_size) {
170    for (column=0;column<size;column+=block_size) {
171      for (int row2=row;row2<row+block_size;row2++) {
172        for (int column2=column;column2<column+block_size;column2++) 
173          which[put++]=row2*size+column2;
174      }
175      starts[set+1]=put;
176      set++;
177    }
178  }
179  OsiClpSolverInterface solver1;
180
181  /***************************************
182     Create model
183     Set variables to be general integer variables although
184     priorities probably mean that it won't matter
185  ***************************************/
186  CoinModel build;
187  // Columns
188  char name[4];
189  for (row=0;row<size;row++) {
190    for (column=0;column<size;column++) {
191      if (row<10) {
192        if (column<10) 
193          sprintf(name,"X%d%d",row,column);
194        else
195          sprintf(name,"X%d%c",row,'A'+(column-10));
196      } else {
197        if (column<10) 
198          sprintf(name,"X%c%d",'A'+(row-10),column);
199        else
200          sprintf(name,"X%c%c",'A'+(row-10),'A'+(column-10));
201      }
202      double value = CoinDrand48()*100.0;
203      build.addColumn(0,NULL,NULL,lo[size*row+column],
204                      up[size*row+column], value, name,true);
205    }
206  }
207  /***************************************
208     Now add in extra variables for N way branching
209  ***************************************/
210  for (row=0;row<size;row++) {
211    for (column=0;column<size;column++) {
212      int iColumn = size*row+column;
213      double value = lo[iColumn];
214      if (value<up[iColumn]) {
215        for (int i=0;i<size;i++)
216          build.addColumn(0,NULL,NULL,0.0,1.0,0.0);
217      } else {
218        // fixed
219        // obviously could do better  if we missed out variables
220        int which = ((int) value) - valueOffset;
221        for (int i=0;i<size;i++) {
222          if (i!=which)
223            build.addColumn(0,NULL,NULL,0.0,0.0,0.0);
224          else
225            build.addColumn(0,NULL,NULL,0.0,1.0,0.0);
226        }
227      }
228    }
229  }
230 
231  /***************************************
232     Now rows
233  ***************************************/
234  double values[]={1.0,1.0,1.0,1.0,
235                   1.0,1.0,1.0,1.0,
236                   1.0,1.0,1.0,1.0,
237                   1.0,1.0,1.0,1.0};
238  int indices[MAX_SIZE+1];
239  double rhs = size==9 ? 45.0 : 120.0;
240  for (row=0;row<3*size;row++) {
241    int iStart = starts[row];
242    for (column=0;column<size;column++) 
243      indices[column]=which[column+iStart];
244    build.addRow(size,indices,values,rhs,rhs);
245  }
246  double values2[MAX_SIZE+1];
247  values2[0]=-1.0;
248  for (row=0;row<size;row++) 
249    values2[row+1]=row+valueOffset;
250  // Now add rows for extra variables
251  for (row=0;row<size;row++) {
252    for (column=0;column<size;column++) {
253      int iColumn = row*size+column;
254      int base = size*size + iColumn*size;
255      indices[0]=iColumn;
256      for (int i=0;i<size;i++)
257        indices[i+1]=base+i;
258      build.addRow(size+1,indices,values2,0.0,0.0);
259    }
260  }
261  solver1.loadFromCoinModel(build);
262  build.writeMps("xx.mps");
263 
264  double time1 = CoinCpuTime();
265  solver1.initialSolve();
266  CbcModel model(solver1);
267  model.solver()->setHintParam(OsiDoReducePrint,true,OsiHintTry);
268  model.solver()->setHintParam(OsiDoScale,false,OsiHintTry);
269  /***************************************
270    Add in All different cut generator and All different branching
271    So we will have integers then cut branching then N way branching
272    in reverse priority order
273  ***************************************/
274 
275  // Cut generator
276  CglAllDifferent allDifferent(3*size,starts,which);
277  model.addCutGenerator(&allDifferent,-99,"allDifferent");
278  model.cutGenerator(0)->setWhatDepth(5);
279  CbcObject ** objects = new CbcObject * [4*size*size];
280  int nObj=0;
281  for (row=0;row<3*size;row++) {
282    int iStart = starts[row];
283    objects[row]= new CbcBranchAllDifferent(&model,size,which+iStart);
284    objects[row]->setPriority(2000+nObj); // do after rest satisfied
285    nObj++;
286  }
287  /***************************************
288     Add in N way branching and while we are at it add in cuts
289  ***************************************/
290  CglStored stored;
291  for (row=0;row<size;row++) {
292    for (column=0;column<size;column++) {
293      int iColumn = row*size+column;
294      int base = size*size + iColumn*size;
295      int i;
296      for ( i=0;i<size;i++)
297        indices[i]=base+i;
298      CbcNWay * obj = new CbcNWay(&model,size,indices,nObj);
299      int seq[200];
300      int newUpper[200];
301      memset(newUpper,0,sizeof(newUpper));
302      for (i=0;i<size;i++) {
303        int state=9999;
304        int one=1;
305        int nFix=1;
306        // Fix real variable
307        seq[0]=iColumn;
308        newUpper[0]=valueOffset+i;
309        int kColumn = base+i;
310        int j;
311        // same row
312        for (j=0;j<size;j++) {
313          int jColumn = row*size+j;
314          int jjColumn = size*size+jColumn*size+i;
315          if (jjColumn!=kColumn) {
316            seq[nFix++]=jjColumn; // must be zero
317          }
318        }
319        // same column
320        for (j=0;j<size;j++) {
321          int jColumn = j*size+column;
322          int jjColumn = size*size+jColumn*size+i;
323          if (jjColumn!=kColumn) {
324            seq[nFix++]=jjColumn; // must be zero
325          }
326        }
327        // same block
328        int kRow = row/block_size;
329        kRow *= block_size;
330        int kCol = column/block_size;
331        kCol *= block_size;
332        for (j=kRow;j<kRow+block_size;j++) {
333          for (int jc=kCol;jc<kCol+block_size;jc++) {
334            int jColumn = j*size+jc;
335            int jjColumn = size*size+jColumn*size+i;
336            if (jjColumn!=kColumn) {
337              seq[nFix++]=jjColumn; // must be zero
338            }
339          }
340        }
341        // seem to need following?
342        const int * upperAddress = newUpper;
343        const int * seqAddress = seq;
344        CbcFixVariable fix(1,&state,&one,&upperAddress,&seqAddress,&nFix,&upperAddress,&seqAddress);
345        obj->setConsequence(indices[i],fix);
346        // Now do as cuts
347        for (int kk=1;kk<nFix;kk++) {
348          int jColumn = seq[kk];
349          int cutInd[2];
350          cutInd[0]=kColumn;
351          if (jColumn>kColumn) {
352            cutInd[1]=jColumn;
353            stored.addCut(-COIN_DBL_MAX,1.0,2,cutInd,values);
354          }
355        }
356      }
357      objects[nObj]= obj;
358      objects[nObj]->setPriority(nObj);
359      nObj++;
360    }
361  }
362  model.addObjects(nObj,objects);
363  for (row=0;row<nObj;row++) 
364    delete objects[row];
365  delete [] objects;
366  model.messageHandler()->setLogLevel(1);
367  model.addCutGenerator(&stored,1,"stored");
368  // Say we want timings
369  int numberGenerators = model.numberCutGenerators();
370  int iGenerator;
371  for (iGenerator=0;iGenerator<numberGenerators;iGenerator++) {
372    CbcCutGenerator * generator = model.cutGenerator(iGenerator);
373    generator->setTiming(true);
374  }
375  // Set this to get all solutions (all ones in newspapers should only have one)
376  //model.setCutoffIncrement(-1.0e6);
377  /***************************************
378     Do branch and bound
379  ***************************************/
380  // Do complete search
381  model.branchAndBound();
382  std::cout<<"took "<<CoinCpuTime()-time1<<" seconds, "
383           <<model.getNodeCount()<<" nodes with objective "
384           <<model.getObjValue()
385           <<(!model.status() ? " Finished" : " Not finished")
386           <<std::endl;
387
388
389  /***************************************
390     Print solution and check it is feasible
391     We could modify output so could be imported by spreadsheet
392  ***************************************/
393  if (model.getMinimizationObjValue()<1.0e50) {
394   
395    const double * solution = model.bestSolution();
396    int put=0;
397    for (row=0;row<size;row++) {
398      for (column=0;column<size;column++) {
399        int value = (int) floor(solution[row*size+column]+0.5);
400        assert (value>=lo[put]&&value<=up[put]);
401        // save for later test
402        lo[put++]=value;
403        printf("%d ",value);
404      }
405      printf("\n");
406    }
407    // check valid
408    bool valid=true;
409    // By row
410    for (row=0;row<size;row++) {
411      put=0;
412      for (column=0;column<size;column++) 
413        which[put++]=row*size+column;
414      assert (put==size);
415      int i;
416      for (i=0;i<put;i++) 
417        which[i]=(int) lo[which[i]];
418      std::sort(which,which+put);
419      int last = valueOffset-1;
420      for (i=0;i<put;i++) {
421        int value=which[i];
422        if (value!=last+1)
423          valid=false;
424        last=value;
425      }
426    }
427    // By column
428    for (column=0;column<size;column++) {
429      put=0;
430      for (row=0;row<size;row++) 
431        which[put++]=row*size+column;
432      assert (put==size);
433      int i;
434      for (i=0;i<put;i++) 
435        which[i]=(int) lo[which[i]];
436      std::sort(which,which+put);
437      int last = valueOffset-1;
438      for (i=0;i<put;i++) {
439        int value=which[i];
440        if (value!=last+1)
441          valid=false;
442        last=value;
443      }
444    }
445    // By box
446    for (row=0;row<size;row+=block_size) {
447      for (column=0;column<size;column+=block_size) {
448        put=0;
449        for (int row2=row;row2<row+block_size;row2++) {
450          for (int column2=column;column2<column+block_size;column2++) 
451            which[put++]=row2*size+column2;
452        }
453        assert (put==size);
454        int i;
455        for (i=0;i<put;i++) 
456          which[i]=(int) lo[which[i]];
457        std::sort(which,which+put);
458        int last = valueOffset-1;
459        for (i=0;i<put;i++) {
460          int value=which[i];
461          if (value!=last+1)
462          valid=false;
463          last=value;
464        }
465      }
466    }
467    if (valid) {
468      printf("solution is valid\n");
469    } else {
470      printf("solution is not valid\n");
471      abort();
472    }
473  }
474  return 0;
475}   
Note: See TracBrowser for help on using the repository browser.