source: trunk/Cbc/examples/sudoku.cpp @ 1574

Last change on this file since 1574 was 1574, checked in by lou, 8 years ago

Change to EPL license notice.

  • Property svn:eol-style set to native
  • Property svn:keywords set to Author Date Id Revision
File size: 14.5 KB
Line 
1// $Id: sudoku.cpp 1574 2011-01-05 01:13:55Z lou $
2// Copyright (C) 2005, International Business Machines
3// Corporation and others.  All Rights Reserved.
4// This code is licensed under the terms of the Eclipse Public License (EPL).
5
6#if defined(_MSC_VER)
7// Turn off compiler warning about long names
8#  pragma warning(disable:4786)
9#endif
10
11#include <cassert>
12#include <iomanip>
13
14
15// For Branch and bound
16#include "OsiSolverInterface.hpp"
17#include "CbcModel.hpp"
18#include "CoinModel.hpp"
19// For all different
20#include "CbcBranchCut.hpp"
21#include "CbcBranchActual.hpp"
22#include "CbcBranchAllDifferent.hpp"
23#include "CbcCutGenerator.hpp"
24#include "CglAllDifferent.hpp"
25#include "OsiClpSolverInterface.hpp"
26#include "CglStored.hpp"
27
28#include  "CoinTime.hpp"
29
30
31/************************************************************************
32
33This shows how we can define a new branching method to solve problems with
34all different constraints.
35
36We are going to solve a sudoku problem such as
37
381, , ,4, , ,7, ,
39 ,2, , ,5, , ,8,
408,7,3, , ,6, , ,9
414, , ,7, , ,1, ,
42 ,5, , ,8, , ,2,
43 , ,6, ,4,9, , ,3
447, , ,1, , ,4, ,
45 ,8, ,6,2, , ,5,
46 , ,9, ,7,3, ,1,6
47
48The input should be exported from spreadsheet as a csv file where cells are
49empty unless value is known.
50
51We set up a fake objective and simple constraints to say sum must be 45
52
53and then we add all different branching (CbcBranchAllDifferent)
54and all different cuts (to fix variables) (CglAllDifferent).
55
56CbcBranchAllDifferent is really just an example of a cut branch.  If we wish to stop x==y
57then we can have two branches - one x <= y-1 and the other x >= y+1.  It should be easy for the user to
58make up similar cut branches for other uses.
59
60Note - this is all we need to solve most 9 x 9 puzzles because they seem to solve
61at root node or easily.  To solve 16 x 16 puzzles we need more.  All different cuts
62need general integer variables to be fixed while we can branch so they are just at bounds. 
63To get round that we can introduce extra 0-1 variables such that general integer x = sum j * delta j
64and then do N way branching on these (CbcNWay) so that we fix one delta j to 1.  At the same time we use the
65new class CbcConsequence (used in N way branching) which when delta j goes to 1 fixes other variables.
66So it will fix x to the correct value and while we are at it we can fix some delta variables in other
67sets to zero (as per all different rules).  Finally as well as storing the instructions which say if
68delta 11 is 1 then delta 21 is 0 we can also add this in as a cut using new trivial cut class
69CglStored.
70
71************************************************************************/
72
73int main (int argc, const char *argv[])
74{
75  // Get data
76  std::string fileName = "./sudoku_sample.csv";
77  if (argc>=2) fileName = argv[1];
78  FILE * fp = fopen(fileName.c_str(),"r");
79  if (!fp) {
80    printf("Unable to open file %s\n",fileName.c_str());
81    exit(0);
82  }
83#define MAX_SIZE 16
84  int valueOffset=1;
85  double lo[MAX_SIZE*MAX_SIZE],up[MAX_SIZE*MAX_SIZE];
86  char line[80];
87  int row,column;
88  /***************************************
89     Read .csv file and see if 9 or 16 Su Doku
90  ***************************************/
91  int size=9;
92  for (row=0;row<size;row++) {
93    fgets(line,80,fp);
94    // Get size of sudoku puzzle (9 or 16)
95    if (!row) {
96      int get=0;
97      size=1;
98      while (line[get]>=32) {
99        if (line[get]==',')
100          size++;
101        get++;
102      }
103      assert (size==9||size==16);
104      printf("Solving Su Doku of size %d\n",size);
105      if (size==16)
106        valueOffset=0;
107    }
108    int get=0;
109    for (column=0;column<size;column++) {
110      lo[size*row+column]=valueOffset;
111      up[size*row+column]=valueOffset-1+size;
112      if (line[get]!=','&&line[get]>=32) {
113        // skip blanks
114        if (line[get]==' ') {
115          get++;
116          continue;
117        }
118        int value = line[get]-'0';
119        if (size==9) {
120          assert (value>=1&&value<=9);
121        } else {
122          assert (size==16);
123          if (value<0||value>9) {
124            if (line[get]=='"') {
125              get++;
126              value = 10 + line[get]-'A';
127              if (value<10||value>15) {
128                value = 10 + line[get]-'a';
129              }
130              get++;
131            } else {
132              value = 10 + line[get]-'A';
133              if (value<10||value>15) {
134                value = 10 + line[get]-'a';
135              }
136          }
137          }
138          assert (value>=0&&value<=15);
139        }
140        lo[size*row+column]=value;
141        up[size*row+column]=value;
142        get++;
143      }
144      get++;
145    }
146  }
147  int block_size = (int) sqrt ((double) size);
148  /***************************************
149     Now build rules for all different
150     3*9 or 3*16 sets of variables
151     Number variables by row*size+column
152  ***************************************/
153  int starts[3*MAX_SIZE+1];
154  int which[3*MAX_SIZE*MAX_SIZE];
155  int put=0;
156  int set=0;
157  starts[0]=0;
158  // By row
159  for (row=0;row<size;row++) {
160    for (column=0;column<size;column++) 
161      which[put++]=row*size+column;
162    starts[set+1]=put;
163    set++;
164  }
165  // By column
166  for (column=0;column<size;column++) {
167    for (row=0;row<size;row++) 
168      which[put++]=row*size+column;
169    starts[set+1]=put;
170    set++;
171  }
172  // By box
173  for (row=0;row<size;row+=block_size) {
174    for (column=0;column<size;column+=block_size) {
175      for (int row2=row;row2<row+block_size;row2++) {
176        for (int column2=column;column2<column+block_size;column2++) 
177          which[put++]=row2*size+column2;
178      }
179      starts[set+1]=put;
180      set++;
181    }
182  }
183  OsiClpSolverInterface solver1;
184
185  /***************************************
186     Create model
187     Set variables to be general integer variables although
188     priorities probably mean that it won't matter
189  ***************************************/
190  CoinModel build;
191  // Columns
192  char name[4];
193  for (row=0;row<size;row++) {
194    for (column=0;column<size;column++) {
195      if (row<10) {
196        if (column<10) 
197          sprintf(name,"X%d%d",row,column);
198        else
199          sprintf(name,"X%d%c",row,'A'+(column-10));
200      } else {
201        if (column<10) 
202          sprintf(name,"X%c%d",'A'+(row-10),column);
203        else
204          sprintf(name,"X%c%c",'A'+(row-10),'A'+(column-10));
205      }
206      double value = CoinDrand48()*100.0;
207      build.addColumn(0,NULL,NULL,lo[size*row+column],
208                      up[size*row+column], value, name,true);
209    }
210  }
211  /***************************************
212     Now add in extra variables for N way branching
213  ***************************************/
214  for (row=0;row<size;row++) {
215    for (column=0;column<size;column++) {
216      int iColumn = size*row+column;
217      double value = lo[iColumn];
218      if (value<up[iColumn]) {
219        for (int i=0;i<size;i++)
220          build.addColumn(0,NULL,NULL,0.0,1.0,0.0);
221      } else {
222        // fixed
223        // obviously could do better  if we missed out variables
224        int which = ((int) value) - valueOffset;
225        for (int i=0;i<size;i++) {
226          if (i!=which)
227            build.addColumn(0,NULL,NULL,0.0,0.0,0.0);
228          else
229            build.addColumn(0,NULL,NULL,0.0,1.0,0.0);
230        }
231      }
232    }
233  }
234 
235  /***************************************
236     Now rows
237  ***************************************/
238  double values[]={1.0,1.0,1.0,1.0,
239                   1.0,1.0,1.0,1.0,
240                   1.0,1.0,1.0,1.0,
241                   1.0,1.0,1.0,1.0};
242  int indices[MAX_SIZE+1];
243  double rhs = size==9 ? 45.0 : 120.0;
244  for (row=0;row<3*size;row++) {
245    int iStart = starts[row];
246    for (column=0;column<size;column++) 
247      indices[column]=which[column+iStart];
248    build.addRow(size,indices,values,rhs,rhs);
249  }
250  double values2[MAX_SIZE+1];
251  values2[0]=-1.0;
252  for (row=0;row<size;row++) 
253    values2[row+1]=row+valueOffset;
254  // Now add rows for extra variables
255  for (row=0;row<size;row++) {
256    for (column=0;column<size;column++) {
257      int iColumn = row*size+column;
258      int base = size*size + iColumn*size;
259      indices[0]=iColumn;
260      for (int i=0;i<size;i++)
261        indices[i+1]=base+i;
262      build.addRow(size+1,indices,values2,0.0,0.0);
263    }
264  }
265  solver1.loadFromCoinModel(build);
266  build.writeMps("xx.mps");
267 
268  double time1 = CoinCpuTime();
269  solver1.initialSolve();
270  CbcModel model(solver1);
271  model.solver()->setHintParam(OsiDoReducePrint,true,OsiHintTry);
272  model.solver()->setHintParam(OsiDoScale,false,OsiHintTry);
273  /***************************************
274    Add in All different cut generator and All different branching
275    So we will have integers then cut branching then N way branching
276    in reverse priority order
277  ***************************************/
278 
279  // Cut generator
280  CglAllDifferent allDifferent(3*size,starts,which);
281  model.addCutGenerator(&allDifferent,-99,"allDifferent");
282  model.cutGenerator(0)->setWhatDepth(5);
283  CbcObject ** objects = new CbcObject * [4*size*size];
284  int nObj=0;
285  for (row=0;row<3*size;row++) {
286    int iStart = starts[row];
287    objects[row]= new CbcBranchAllDifferent(&model,size,which+iStart);
288    objects[row]->setPriority(2000+nObj); // do after rest satisfied
289    nObj++;
290  }
291  /***************************************
292     Add in N way branching and while we are at it add in cuts
293  ***************************************/
294  CglStored stored;
295  for (row=0;row<size;row++) {
296    for (column=0;column<size;column++) {
297      int iColumn = row*size+column;
298      int base = size*size + iColumn*size;
299      int i;
300      for ( i=0;i<size;i++)
301        indices[i]=base+i;
302      CbcNWay * obj = new CbcNWay(&model,size,indices,nObj);
303      int seq[200];
304      int newUpper[200];
305      memset(newUpper,0,sizeof(newUpper));
306      for (i=0;i<size;i++) {
307        int state=9999;
308        int one=1;
309        int nFix=1;
310        // Fix real variable
311        seq[0]=iColumn;
312        newUpper[0]=valueOffset+i;
313        int kColumn = base+i;
314        int j;
315        // same row
316        for (j=0;j<size;j++) {
317          int jColumn = row*size+j;
318          int jjColumn = size*size+jColumn*size+i;
319          if (jjColumn!=kColumn) {
320            seq[nFix++]=jjColumn; // must be zero
321          }
322        }
323        // same column
324        for (j=0;j<size;j++) {
325          int jColumn = j*size+column;
326          int jjColumn = size*size+jColumn*size+i;
327          if (jjColumn!=kColumn) {
328            seq[nFix++]=jjColumn; // must be zero
329          }
330        }
331        // same block
332        int kRow = row/block_size;
333        kRow *= block_size;
334        int kCol = column/block_size;
335        kCol *= block_size;
336        for (j=kRow;j<kRow+block_size;j++) {
337          for (int jc=kCol;jc<kCol+block_size;jc++) {
338            int jColumn = j*size+jc;
339            int jjColumn = size*size+jColumn*size+i;
340            if (jjColumn!=kColumn) {
341              seq[nFix++]=jjColumn; // must be zero
342            }
343          }
344        }
345        // seem to need following?
346        const int * upperAddress = newUpper;
347        const int * seqAddress = seq;
348        CbcFixVariable fix(1,&state,&one,&upperAddress,&seqAddress,&nFix,&upperAddress,&seqAddress);
349        obj->setConsequence(indices[i],fix);
350        // Now do as cuts
351        for (int kk=1;kk<nFix;kk++) {
352          int jColumn = seq[kk];
353          int cutInd[2];
354          cutInd[0]=kColumn;
355          if (jColumn>kColumn) {
356            cutInd[1]=jColumn;
357            stored.addCut(-COIN_DBL_MAX,1.0,2,cutInd,values);
358          }
359        }
360      }
361      objects[nObj]= obj;
362      objects[nObj]->setPriority(nObj);
363      nObj++;
364    }
365  }
366  model.addObjects(nObj,objects);
367  for (row=0;row<nObj;row++) 
368    delete objects[row];
369  delete [] objects;
370  model.messageHandler()->setLogLevel(1);
371  model.addCutGenerator(&stored,1,"stored");
372  // Say we want timings
373  int numberGenerators = model.numberCutGenerators();
374  int iGenerator;
375  for (iGenerator=0;iGenerator<numberGenerators;iGenerator++) {
376    CbcCutGenerator * generator = model.cutGenerator(iGenerator);
377    generator->setTiming(true);
378  }
379  // Set this to get all solutions (all ones in newspapers should only have one)
380  //model.setCutoffIncrement(-1.0e6);
381  /***************************************
382     Do branch and bound
383  ***************************************/
384  // Do complete search
385  model.branchAndBound();
386  std::cout<<"took "<<CoinCpuTime()-time1<<" seconds, "
387           <<model.getNodeCount()<<" nodes with objective "
388           <<model.getObjValue()
389           <<(!model.status() ? " Finished" : " Not finished")
390           <<std::endl;
391
392
393  /***************************************
394     Print solution and check it is feasible
395     We could modify output so could be imported by spreadsheet
396  ***************************************/
397  if (model.getMinimizationObjValue()<1.0e50) {
398   
399    const double * solution = model.bestSolution();
400    int put=0;
401    for (row=0;row<size;row++) {
402      for (column=0;column<size;column++) {
403        int value = (int) floor(solution[row*size+column]+0.5);
404        assert (value>=lo[put]&&value<=up[put]);
405        // save for later test
406        lo[put++]=value;
407        printf("%d ",value);
408      }
409      printf("\n");
410    }
411    // check valid
412    bool valid=true;
413    // By row
414    for (row=0;row<size;row++) {
415      put=0;
416      for (column=0;column<size;column++) 
417        which[put++]=row*size+column;
418      assert (put==size);
419      int i;
420      for (i=0;i<put;i++) 
421        which[i]=(int) lo[which[i]];
422      std::sort(which,which+put);
423      int last = valueOffset-1;
424      for (i=0;i<put;i++) {
425        int value=which[i];
426        if (value!=last+1)
427          valid=false;
428        last=value;
429      }
430    }
431    // By column
432    for (column=0;column<size;column++) {
433      put=0;
434      for (row=0;row<size;row++) 
435        which[put++]=row*size+column;
436      assert (put==size);
437      int i;
438      for (i=0;i<put;i++) 
439        which[i]=(int) lo[which[i]];
440      std::sort(which,which+put);
441      int last = valueOffset-1;
442      for (i=0;i<put;i++) {
443        int value=which[i];
444        if (value!=last+1)
445          valid=false;
446        last=value;
447      }
448    }
449    // By box
450    for (row=0;row<size;row+=block_size) {
451      for (column=0;column<size;column+=block_size) {
452        put=0;
453        for (int row2=row;row2<row+block_size;row2++) {
454          for (int column2=column;column2<column+block_size;column2++) 
455            which[put++]=row2*size+column2;
456        }
457        assert (put==size);
458        int i;
459        for (i=0;i<put;i++) 
460          which[i]=(int) lo[which[i]];
461        std::sort(which,which+put);
462        int last = valueOffset-1;
463        for (i=0;i<put;i++) {
464          int value=which[i];
465          if (value!=last+1)
466          valid=false;
467          last=value;
468        }
469      }
470    }
471    if (valid) {
472      printf("solution is valid\n");
473    } else {
474      printf("solution is not valid\n");
475      abort();
476    }
477  }
478  return 0;
479}   
Note: See TracBrowser for help on using the repository browser.