source: trunk/Cbc/examples/sudoku.cpp @ 1464

Last change on this file since 1464 was 1464, checked in by stefan, 9 years ago

merge split branch into trunk; fix some examples

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 14.4 KB
Line 
1// Copyright (C) 2005, International Business Machines
2// Corporation and others.  All Rights Reserved.
3#if defined(_MSC_VER)
4// Turn off compiler warning about long names
5#  pragma warning(disable:4786)
6#endif
7
8#include <cassert>
9#include <iomanip>
10
11
12// For Branch and bound
13#include "OsiSolverInterface.hpp"
14#include "CbcModel.hpp"
15#include "CoinModel.hpp"
16// For all different
17#include "CbcBranchCut.hpp"
18#include "CbcBranchActual.hpp"
19#include "CbcBranchAllDifferent.hpp"
20#include "CbcCutGenerator.hpp"
21#include "CglAllDifferent.hpp"
22#include "OsiClpSolverInterface.hpp"
23#include "CglStored.hpp"
24
25#include  "CoinTime.hpp"
26
27
28/************************************************************************
29
30This shows how we can define a new branching method to solve problems with
31all different constraints.
32
33We are going to solve a sudoku problem such as
34
351, , ,4, , ,7, ,
36 ,2, , ,5, , ,8,
378,7,3, , ,6, , ,9
384, , ,7, , ,1, ,
39 ,5, , ,8, , ,2,
40 , ,6, ,4,9, , ,3
417, , ,1, , ,4, ,
42 ,8, ,6,2, , ,5,
43 , ,9, ,7,3, ,1,6
44
45The input should be exported from spreadsheet as a csv file where cells are
46empty unless value is known.
47
48We set up a fake objective and simple constraints to say sum must be 45
49
50and then we add all different branching (CbcBranchAllDifferent)
51and all different cuts (to fix variables) (CglAllDifferent).
52
53CbcBranchAllDifferent is really just an example of a cut branch.  If we wish to stop x==y
54then we can have two branches - one x <= y-1 and the other x >= y+1.  It should be easy for the user to
55make up similar cut branches for other uses.
56
57Note - this is all we need to solve most 9 x 9 puzzles because they seem to solve
58at root node or easily.  To solve 16 x 16 puzzles we need more.  All different cuts
59need general integer variables to be fixed while we can branch so they are just at bounds. 
60To get round that we can introduce extra 0-1 variables such that general integer x = sum j * delta j
61and then do N way branching on these (CbcNWay) so that we fix one delta j to 1.  At the same time we use the
62new class CbcConsequence (used in N way branching) which when delta j goes to 1 fixes other variables.
63So it will fix x to the correct value and while we are at it we can fix some delta variables in other
64sets to zero (as per all different rules).  Finally as well as storing the instructions which say if
65delta 11 is 1 then delta 21 is 0 we can also add this in as a cut using new trivial cut class
66CglStored.
67
68************************************************************************/
69
70int main (int argc, const char *argv[])
71{
72  // Get data
73  std::string fileName = "./sudoku_sample.csv";
74  if (argc>=2) fileName = argv[1];
75  FILE * fp = fopen(fileName.c_str(),"r");
76  if (!fp) {
77    printf("Unable to open file %s\n",fileName.c_str());
78    exit(0);
79  }
80#define MAX_SIZE 16
81  int valueOffset=1;
82  double lo[MAX_SIZE*MAX_SIZE],up[MAX_SIZE*MAX_SIZE];
83  char line[80];
84  int row,column;
85  /***************************************
86     Read .csv file and see if 9 or 16 Su Doku
87  ***************************************/
88  int size=9;
89  for (row=0;row<size;row++) {
90    fgets(line,80,fp);
91    // Get size of sudoku puzzle (9 or 16)
92    if (!row) {
93      int get=0;
94      size=1;
95      while (line[get]>=32) {
96        if (line[get]==',')
97          size++;
98        get++;
99      }
100      assert (size==9||size==16);
101      printf("Solving Su Doku of size %d\n",size);
102      if (size==16)
103        valueOffset=0;
104    }
105    int get=0;
106    for (column=0;column<size;column++) {
107      lo[size*row+column]=valueOffset;
108      up[size*row+column]=valueOffset-1+size;
109      if (line[get]!=','&&line[get]>=32) {
110        // skip blanks
111        if (line[get]==' ') {
112          get++;
113          continue;
114        }
115        int value = line[get]-'0';
116        if (size==9) {
117          assert (value>=1&&value<=9);
118        } else {
119          assert (size==16);
120          if (value<0||value>9) {
121            if (line[get]=='"') {
122              get++;
123              value = 10 + line[get]-'A';
124              if (value<10||value>15) {
125                value = 10 + line[get]-'a';
126              }
127              get++;
128            } else {
129              value = 10 + line[get]-'A';
130              if (value<10||value>15) {
131                value = 10 + line[get]-'a';
132              }
133          }
134          }
135          assert (value>=0&&value<=15);
136        }
137        lo[size*row+column]=value;
138        up[size*row+column]=value;
139        get++;
140      }
141      get++;
142    }
143  }
144  int block_size = (int) sqrt ((double) size);
145  /***************************************
146     Now build rules for all different
147     3*9 or 3*16 sets of variables
148     Number variables by row*size+column
149  ***************************************/
150  int starts[3*MAX_SIZE+1];
151  int which[3*MAX_SIZE*MAX_SIZE];
152  int put=0;
153  int set=0;
154  starts[0]=0;
155  // By row
156  for (row=0;row<size;row++) {
157    for (column=0;column<size;column++) 
158      which[put++]=row*size+column;
159    starts[set+1]=put;
160    set++;
161  }
162  // By column
163  for (column=0;column<size;column++) {
164    for (row=0;row<size;row++) 
165      which[put++]=row*size+column;
166    starts[set+1]=put;
167    set++;
168  }
169  // By box
170  for (row=0;row<size;row+=block_size) {
171    for (column=0;column<size;column+=block_size) {
172      for (int row2=row;row2<row+block_size;row2++) {
173        for (int column2=column;column2<column+block_size;column2++) 
174          which[put++]=row2*size+column2;
175      }
176      starts[set+1]=put;
177      set++;
178    }
179  }
180  OsiClpSolverInterface solver1;
181
182  /***************************************
183     Create model
184     Set variables to be general integer variables although
185     priorities probably mean that it won't matter
186  ***************************************/
187  CoinModel build;
188  // Columns
189  char name[4];
190  for (row=0;row<size;row++) {
191    for (column=0;column<size;column++) {
192      if (row<10) {
193        if (column<10) 
194          sprintf(name,"X%d%d",row,column);
195        else
196          sprintf(name,"X%d%c",row,'A'+(column-10));
197      } else {
198        if (column<10) 
199          sprintf(name,"X%c%d",'A'+(row-10),column);
200        else
201          sprintf(name,"X%c%c",'A'+(row-10),'A'+(column-10));
202      }
203      double value = CoinDrand48()*100.0;
204      build.addColumn(0,NULL,NULL,lo[size*row+column],
205                      up[size*row+column], value, name,true);
206    }
207  }
208  /***************************************
209     Now add in extra variables for N way branching
210  ***************************************/
211  for (row=0;row<size;row++) {
212    for (column=0;column<size;column++) {
213      int iColumn = size*row+column;
214      double value = lo[iColumn];
215      if (value<up[iColumn]) {
216        for (int i=0;i<size;i++)
217          build.addColumn(0,NULL,NULL,0.0,1.0,0.0);
218      } else {
219        // fixed
220        // obviously could do better  if we missed out variables
221        int which = ((int) value) - valueOffset;
222        for (int i=0;i<size;i++) {
223          if (i!=which)
224            build.addColumn(0,NULL,NULL,0.0,0.0,0.0);
225          else
226            build.addColumn(0,NULL,NULL,0.0,1.0,0.0);
227        }
228      }
229    }
230  }
231 
232  /***************************************
233     Now rows
234  ***************************************/
235  double values[]={1.0,1.0,1.0,1.0,
236                   1.0,1.0,1.0,1.0,
237                   1.0,1.0,1.0,1.0,
238                   1.0,1.0,1.0,1.0};
239  int indices[MAX_SIZE+1];
240  double rhs = size==9 ? 45.0 : 120.0;
241  for (row=0;row<3*size;row++) {
242    int iStart = starts[row];
243    for (column=0;column<size;column++) 
244      indices[column]=which[column+iStart];
245    build.addRow(size,indices,values,rhs,rhs);
246  }
247  double values2[MAX_SIZE+1];
248  values2[0]=-1.0;
249  for (row=0;row<size;row++) 
250    values2[row+1]=row+valueOffset;
251  // Now add rows for extra variables
252  for (row=0;row<size;row++) {
253    for (column=0;column<size;column++) {
254      int iColumn = row*size+column;
255      int base = size*size + iColumn*size;
256      indices[0]=iColumn;
257      for (int i=0;i<size;i++)
258        indices[i+1]=base+i;
259      build.addRow(size+1,indices,values2,0.0,0.0);
260    }
261  }
262  solver1.loadFromCoinModel(build);
263  build.writeMps("xx.mps");
264 
265  double time1 = CoinCpuTime();
266  solver1.initialSolve();
267  CbcModel model(solver1);
268  model.solver()->setHintParam(OsiDoReducePrint,true,OsiHintTry);
269  model.solver()->setHintParam(OsiDoScale,false,OsiHintTry);
270  /***************************************
271    Add in All different cut generator and All different branching
272    So we will have integers then cut branching then N way branching
273    in reverse priority order
274  ***************************************/
275 
276  // Cut generator
277  CglAllDifferent allDifferent(3*size,starts,which);
278  model.addCutGenerator(&allDifferent,-99,"allDifferent");
279  model.cutGenerator(0)->setWhatDepth(5);
280  CbcObject ** objects = new CbcObject * [4*size*size];
281  int nObj=0;
282  for (row=0;row<3*size;row++) {
283    int iStart = starts[row];
284    objects[row]= new CbcBranchAllDifferent(&model,size,which+iStart);
285    objects[row]->setPriority(2000+nObj); // do after rest satisfied
286    nObj++;
287  }
288  /***************************************
289     Add in N way branching and while we are at it add in cuts
290  ***************************************/
291  CglStored stored;
292  for (row=0;row<size;row++) {
293    for (column=0;column<size;column++) {
294      int iColumn = row*size+column;
295      int base = size*size + iColumn*size;
296      int i;
297      for ( i=0;i<size;i++)
298        indices[i]=base+i;
299      CbcNWay * obj = new CbcNWay(&model,size,indices,nObj);
300      int seq[200];
301      int newUpper[200];
302      memset(newUpper,0,sizeof(newUpper));
303      for (i=0;i<size;i++) {
304        int state=9999;
305        int one=1;
306        int nFix=1;
307        // Fix real variable
308        seq[0]=iColumn;
309        newUpper[0]=valueOffset+i;
310        int kColumn = base+i;
311        int j;
312        // same row
313        for (j=0;j<size;j++) {
314          int jColumn = row*size+j;
315          int jjColumn = size*size+jColumn*size+i;
316          if (jjColumn!=kColumn) {
317            seq[nFix++]=jjColumn; // must be zero
318          }
319        }
320        // same column
321        for (j=0;j<size;j++) {
322          int jColumn = j*size+column;
323          int jjColumn = size*size+jColumn*size+i;
324          if (jjColumn!=kColumn) {
325            seq[nFix++]=jjColumn; // must be zero
326          }
327        }
328        // same block
329        int kRow = row/block_size;
330        kRow *= block_size;
331        int kCol = column/block_size;
332        kCol *= block_size;
333        for (j=kRow;j<kRow+block_size;j++) {
334          for (int jc=kCol;jc<kCol+block_size;jc++) {
335            int jColumn = j*size+jc;
336            int jjColumn = size*size+jColumn*size+i;
337            if (jjColumn!=kColumn) {
338              seq[nFix++]=jjColumn; // must be zero
339            }
340          }
341        }
342        // seem to need following?
343        const int * upperAddress = newUpper;
344        const int * seqAddress = seq;
345        CbcFixVariable fix(1,&state,&one,&upperAddress,&seqAddress,&nFix,&upperAddress,&seqAddress);
346        obj->setConsequence(indices[i],fix);
347        // Now do as cuts
348        for (int kk=1;kk<nFix;kk++) {
349          int jColumn = seq[kk];
350          int cutInd[2];
351          cutInd[0]=kColumn;
352          if (jColumn>kColumn) {
353            cutInd[1]=jColumn;
354            stored.addCut(-COIN_DBL_MAX,1.0,2,cutInd,values);
355          }
356        }
357      }
358      objects[nObj]= obj;
359      objects[nObj]->setPriority(nObj);
360      nObj++;
361    }
362  }
363  model.addObjects(nObj,objects);
364  for (row=0;row<nObj;row++) 
365    delete objects[row];
366  delete [] objects;
367  model.messageHandler()->setLogLevel(1);
368  model.addCutGenerator(&stored,1,"stored");
369  // Say we want timings
370  int numberGenerators = model.numberCutGenerators();
371  int iGenerator;
372  for (iGenerator=0;iGenerator<numberGenerators;iGenerator++) {
373    CbcCutGenerator * generator = model.cutGenerator(iGenerator);
374    generator->setTiming(true);
375  }
376  // Set this to get all solutions (all ones in newspapers should only have one)
377  //model.setCutoffIncrement(-1.0e6);
378  /***************************************
379     Do branch and bound
380  ***************************************/
381  // Do complete search
382  model.branchAndBound();
383  std::cout<<"took "<<CoinCpuTime()-time1<<" seconds, "
384           <<model.getNodeCount()<<" nodes with objective "
385           <<model.getObjValue()
386           <<(!model.status() ? " Finished" : " Not finished")
387           <<std::endl;
388
389
390  /***************************************
391     Print solution and check it is feasible
392     We could modify output so could be imported by spreadsheet
393  ***************************************/
394  if (model.getMinimizationObjValue()<1.0e50) {
395   
396    const double * solution = model.bestSolution();
397    int put=0;
398    for (row=0;row<size;row++) {
399      for (column=0;column<size;column++) {
400        int value = (int) floor(solution[row*size+column]+0.5);
401        assert (value>=lo[put]&&value<=up[put]);
402        // save for later test
403        lo[put++]=value;
404        printf("%d ",value);
405      }
406      printf("\n");
407    }
408    // check valid
409    bool valid=true;
410    // By row
411    for (row=0;row<size;row++) {
412      put=0;
413      for (column=0;column<size;column++) 
414        which[put++]=row*size+column;
415      assert (put==size);
416      int i;
417      for (i=0;i<put;i++) 
418        which[i]=(int) lo[which[i]];
419      std::sort(which,which+put);
420      int last = valueOffset-1;
421      for (i=0;i<put;i++) {
422        int value=which[i];
423        if (value!=last+1)
424          valid=false;
425        last=value;
426      }
427    }
428    // By column
429    for (column=0;column<size;column++) {
430      put=0;
431      for (row=0;row<size;row++) 
432        which[put++]=row*size+column;
433      assert (put==size);
434      int i;
435      for (i=0;i<put;i++) 
436        which[i]=(int) lo[which[i]];
437      std::sort(which,which+put);
438      int last = valueOffset-1;
439      for (i=0;i<put;i++) {
440        int value=which[i];
441        if (value!=last+1)
442          valid=false;
443        last=value;
444      }
445    }
446    // By box
447    for (row=0;row<size;row+=block_size) {
448      for (column=0;column<size;column+=block_size) {
449        put=0;
450        for (int row2=row;row2<row+block_size;row2++) {
451          for (int column2=column;column2<column+block_size;column2++) 
452            which[put++]=row2*size+column2;
453        }
454        assert (put==size);
455        int i;
456        for (i=0;i<put;i++) 
457          which[i]=(int) lo[which[i]];
458        std::sort(which,which+put);
459        int last = valueOffset-1;
460        for (i=0;i<put;i++) {
461          int value=which[i];
462          if (value!=last+1)
463          valid=false;
464          last=value;
465        }
466      }
467    }
468    if (valid) {
469      printf("solution is valid\n");
470    } else {
471      printf("solution is not valid\n");
472      abort();
473    }
474  }
475  return 0;
476}   
Note: See TracBrowser for help on using the repository browser.