# source:trunk/Cbc/examples/sudoku.cpp

Last change on this file was 2469, checked in by unxusr, 2 months ago

formatting

• Property svn:eol-style set to `native`
• Property svn:keywords set to `Author Date Id Revision`
File size: 15.2 KB
Line
1// \$Id: sudoku.cpp 2469 2019-01-06 23:17:46Z forrest \$
5
6#include <cassert>
7#include <iomanip>
8
9#include "CoinPragma.hpp"
10// For Branch and bound
11#include "OsiSolverInterface.hpp"
12#include "CbcModel.hpp"
13#include "CoinModel.hpp"
14// For all different
15#include "CbcBranchCut.hpp"
16#include "CbcBranchActual.hpp"
17#include "CbcBranchAllDifferent.hpp"
18#include "CbcCutGenerator.hpp"
19#include "CglAllDifferent.hpp"
20#include "OsiClpSolverInterface.hpp"
21#include "CglStored.hpp"
22
23#include "CoinTime.hpp"
24
25/************************************************************************
26
27This shows how we can define a new branching method to solve problems with
28all different constraints.
29
30We are going to solve a sudoku problem such as
31
321, , ,4, , ,7, ,
33 ,2, , ,5, , ,8,
348,7,3, , ,6, , ,9
354, , ,7, , ,1, ,
36 ,5, , ,8, , ,2,
37 , ,6, ,4,9, , ,3
387, , ,1, , ,4, ,
39 ,8, ,6,2, , ,5,
40 , ,9, ,7,3, ,1,6
41
42The input should be exported from spreadsheet as a csv file where cells are
43empty unless value is known.
44
45We set up a fake objective and simple constraints to say sum must be 45
46
47and then we add all different branching (CbcBranchAllDifferent)
48and all different cuts (to fix variables) (CglAllDifferent).
49
50CbcBranchAllDifferent is really just an example of a cut branch.  If we wish to stop x==y
51then we can have two branches - one x <= y-1 and the other x >= y+1.  It should be easy for the user to
52make up similar cut branches for other uses.
53
54Note - this is all we need to solve most 9 x 9 puzzles because they seem to solve
55at root node or easily.  To solve 16 x 16 puzzles we need more.  All different cuts
56need general integer variables to be fixed while we can branch so they are just at bounds.
57To get round that we can introduce extra 0-1 variables such that general integer x = sum j * delta j
58and then do N way branching on these (CbcNWay) so that we fix one delta j to 1.  At the same time we use the
59new class CbcConsequence (used in N way branching) which when delta j goes to 1 fixes other variables.
60So it will fix x to the correct value and while we are at it we can fix some delta variables in other
61sets to zero (as per all different rules).  Finally as well as storing the instructions which say if
62delta 11 is 1 then delta 21 is 0 we can also add this in as a cut using new trivial cut class
63CglStored.
64
65************************************************************************/
66
67int main(int argc, const char *argv[])
68{
69  // Get data
70  std::string fileName = "./sudoku_sample.csv";
71  if (argc >= 2)
72    fileName = argv[1];
73  FILE *fp = fopen(fileName.c_str(), "r");
74  if (!fp) {
75    printf("Unable to open file %s\n", fileName.c_str());
76    exit(0);
77  }
78#define MAX_SIZE 16
79  int valueOffset = 1;
80  double lo[MAX_SIZE * MAX_SIZE], up[MAX_SIZE * MAX_SIZE];
81  char line[80];
82  int row, column;
83  /***************************************
84     Read .csv file and see if 9 or 16 Su Doku
85  ***************************************/
86  int size = 9;
87  for (row = 0; row < size; row++) {
88    fgets(line, 80, fp);
89    // Get size of sudoku puzzle (9 or 16)
90    if (!row) {
91      int get = 0;
92      size = 1;
93      while (line[get] >= 32) {
94        if (line[get] == ',')
95          size++;
96        get++;
97      }
98      assert(size == 9 || size == 16);
99      printf("Solving Su Doku of size %d\n", size);
100      if (size == 16)
101        valueOffset = 0;
102    }
103    int get = 0;
104    for (column = 0; column < size; column++) {
105      lo[size * row + column] = valueOffset;
106      up[size * row + column] = valueOffset - 1 + size;
107      if (line[get] != ',' && line[get] >= 32) {
108        // skip blanks
109        if (line[get] == ' ') {
110          get++;
111          continue;
112        }
113        int value = line[get] - '0';
114        if (size == 9) {
115          assert(value >= 1 && value <= 9);
116        } else {
117          assert(size == 16);
118          if (value < 0 || value > 9) {
119            if (line[get] == '"') {
120              get++;
121              value = 10 + line[get] - 'A';
122              if (value < 10 || value > 15) {
123                value = 10 + line[get] - 'a';
124              }
125              get++;
126            } else {
127              value = 10 + line[get] - 'A';
128              if (value < 10 || value > 15) {
129                value = 10 + line[get] - 'a';
130              }
131            }
132          }
133          assert(value >= 0 && value <= 15);
134        }
135        lo[size * row + column] = value;
136        up[size * row + column] = value;
137        get++;
138      }
139      get++;
140    }
141  }
142  int block_size = (int)sqrt((double)size);
143  /***************************************
144     Now build rules for all different
145     3*9 or 3*16 sets of variables
146     Number variables by row*size+column
147  ***************************************/
148  int starts[3 * MAX_SIZE + 1];
149  int which[3 * MAX_SIZE * MAX_SIZE];
150  int put = 0;
151  int set = 0;
152  starts[0] = 0;
153  // By row
154  for (row = 0; row < size; row++) {
155    for (column = 0; column < size; column++)
156      which[put++] = row * size + column;
157    starts[set + 1] = put;
158    set++;
159  }
160  // By column
161  for (column = 0; column < size; column++) {
162    for (row = 0; row < size; row++)
163      which[put++] = row * size + column;
164    starts[set + 1] = put;
165    set++;
166  }
167  // By box
168  for (row = 0; row < size; row += block_size) {
169    for (column = 0; column < size; column += block_size) {
170      for (int row2 = row; row2 < row + block_size; row2++) {
171        for (int column2 = column; column2 < column + block_size; column2++)
172          which[put++] = row2 * size + column2;
173      }
174      starts[set + 1] = put;
175      set++;
176    }
177  }
178  OsiClpSolverInterface solver1;
179
180  /***************************************
181     Create model
182     Set variables to be general integer variables although
183     priorities probably mean that it won't matter
184  ***************************************/
185  CoinModel build;
186  // Columns
187  char name[4];
188  for (row = 0; row < size; row++) {
189    for (column = 0; column < size; column++) {
190      if (row < 10) {
191        if (column < 10)
192          sprintf(name, "X%d%d", row, column);
193        else
194          sprintf(name, "X%d%c", row, 'A' + (column - 10));
195      } else {
196        if (column < 10)
197          sprintf(name, "X%c%d", 'A' + (row - 10), column);
198        else
199          sprintf(name, "X%c%c", 'A' + (row - 10), 'A' + (column - 10));
200      }
201      double value = CoinDrand48() * 100.0;
202      build.addColumn(0, NULL, NULL, lo[size * row + column],
203        up[size * row + column], value, name, true);
204    }
205  }
206  /***************************************
207     Now add in extra variables for N way branching
208  ***************************************/
209  for (row = 0; row < size; row++) {
210    for (column = 0; column < size; column++) {
211      int iColumn = size * row + column;
212      double value = lo[iColumn];
213      if (value < up[iColumn]) {
214        for (int i = 0; i < size; i++)
215          build.addColumn(0, NULL, NULL, 0.0, 1.0, 0.0);
216      } else {
217        // fixed
218        // obviously could do better  if we missed out variables
219        int which = ((int)value) - valueOffset;
220        for (int i = 0; i < size; i++) {
221          if (i != which)
222            build.addColumn(0, NULL, NULL, 0.0, 0.0, 0.0);
223          else
224            build.addColumn(0, NULL, NULL, 0.0, 1.0, 0.0);
225        }
226      }
227    }
228  }
229
230  /***************************************
231     Now rows
232  ***************************************/
233  double values[] = { 1.0, 1.0, 1.0, 1.0,
234    1.0, 1.0, 1.0, 1.0,
235    1.0, 1.0, 1.0, 1.0,
236    1.0, 1.0, 1.0, 1.0 };
237  int indices[MAX_SIZE + 1];
238  double rhs = size == 9 ? 45.0 : 120.0;
239  for (row = 0; row < 3 * size; row++) {
240    int iStart = starts[row];
241    for (column = 0; column < size; column++)
242      indices[column] = which[column + iStart];
243    build.addRow(size, indices, values, rhs, rhs);
244  }
245  double values2[MAX_SIZE + 1];
246  values2[0] = -1.0;
247  for (row = 0; row < size; row++)
248    values2[row + 1] = row + valueOffset;
249  // Now add rows for extra variables
250  for (row = 0; row < size; row++) {
251    for (column = 0; column < size; column++) {
252      int iColumn = row * size + column;
253      int base = size * size + iColumn * size;
254      indices[0] = iColumn;
255      for (int i = 0; i < size; i++)
256        indices[i + 1] = base + i;
257      build.addRow(size + 1, indices, values2, 0.0, 0.0);
258    }
259  }
261  build.writeMps("xx.mps");
262
263  double time1 = CoinCpuTime();
264  solver1.initialSolve();
265  CbcModel model(solver1);
266  model.solver()->setHintParam(OsiDoReducePrint, true, OsiHintTry);
267  model.solver()->setHintParam(OsiDoScale, false, OsiHintTry);
268  /***************************************
269    Add in All different cut generator and All different branching
270    So we will have integers then cut branching then N way branching
271    in reverse priority order
272  ***************************************/
273
274  // Cut generator
275  CglAllDifferent allDifferent(3 * size, starts, which);
277  model.cutGenerator(0)->setWhatDepth(5);
278  CbcObject **objects = new CbcObject *[4 * size * size];
279  int nObj = 0;
280  for (row = 0; row < 3 * size; row++) {
281    int iStart = starts[row];
282    objects[row] = new CbcBranchAllDifferent(&model, size, which + iStart);
283    objects[row]->setPriority(2000 + nObj); // do after rest satisfied
284    nObj++;
285  }
286  /***************************************
287     Add in N way branching and while we are at it add in cuts
288  ***************************************/
289  CglStored stored;
290  for (row = 0; row < size; row++) {
291    for (column = 0; column < size; column++) {
292      int iColumn = row * size + column;
293      int base = size * size + iColumn * size;
294      int i;
295      for (i = 0; i < size; i++)
296        indices[i] = base + i;
297      CbcNWay *obj = new CbcNWay(&model, size, indices, nObj);
298      int seq[200];
299      int newUpper[200];
300      memset(newUpper, 0, sizeof(newUpper));
301      for (i = 0; i < size; i++) {
302        int state = 9999;
303        int one = 1;
304        int nFix = 1;
305        // Fix real variable
306        seq[0] = iColumn;
307        newUpper[0] = valueOffset + i;
308        int kColumn = base + i;
309        int j;
310        // same row
311        for (j = 0; j < size; j++) {
312          int jColumn = row * size + j;
313          int jjColumn = size * size + jColumn * size + i;
314          if (jjColumn != kColumn) {
315            seq[nFix++] = jjColumn; // must be zero
316          }
317        }
318        // same column
319        for (j = 0; j < size; j++) {
320          int jColumn = j * size + column;
321          int jjColumn = size * size + jColumn * size + i;
322          if (jjColumn != kColumn) {
323            seq[nFix++] = jjColumn; // must be zero
324          }
325        }
326        // same block
327        int kRow = row / block_size;
328        kRow *= block_size;
329        int kCol = column / block_size;
330        kCol *= block_size;
331        for (j = kRow; j < kRow + block_size; j++) {
332          for (int jc = kCol; jc < kCol + block_size; jc++) {
333            int jColumn = j * size + jc;
334            int jjColumn = size * size + jColumn * size + i;
335            if (jjColumn != kColumn) {
336              seq[nFix++] = jjColumn; // must be zero
337            }
338          }
339        }
340        // seem to need following?
341        const int *upperAddress = newUpper;
342        const int *seqAddress = seq;
344        obj->setConsequence(indices[i], fix);
345        // Now do as cuts
346        for (int kk = 1; kk < nFix; kk++) {
347          int jColumn = seq[kk];
348          int cutInd[2];
349          cutInd[0] = kColumn;
350          if (jColumn > kColumn) {
351            cutInd[1] = jColumn;
352            stored.addCut(-COIN_DBL_MAX, 1.0, 2, cutInd, values);
353          }
354        }
355      }
356      objects[nObj] = obj;
357      objects[nObj]->setPriority(nObj);
358      nObj++;
359    }
360  }
362  for (row = 0; row < nObj; row++)
363    delete objects[row];
364  delete[] objects;
365  model.messageHandler()->setLogLevel(1);
367  // Say we want timings
368  int numberGenerators = model.numberCutGenerators();
369  int iGenerator;
370  for (iGenerator = 0; iGenerator < numberGenerators; iGenerator++) {
371    CbcCutGenerator *generator = model.cutGenerator(iGenerator);
372    generator->setTiming(true);
373  }
374  // Set this to get all solutions (all ones in newspapers should only have one)
375  //model.setCutoffIncrement(-1.0e6);
376  /***************************************
377     Do branch and bound
378  ***************************************/
379  // Do complete search
380  model.branchAndBound();
381  std::cout << "took " << CoinCpuTime() - time1 << " seconds, "
382            << model.getNodeCount() << " nodes with objective "
383            << model.getObjValue()
384            << (!model.status() ? " Finished" : " Not finished")
385            << std::endl;
386
387  /***************************************
388     Print solution and check it is feasible
389     We could modify output so could be imported by spreadsheet
390  ***************************************/
391  if (model.getMinimizationObjValue() < 1.0e50) {
392
393    const double *solution = model.bestSolution();
394    int put = 0;
395    for (row = 0; row < size; row++) {
396      for (column = 0; column < size; column++) {
397        int value = (int)floor(solution[row * size + column] + 0.5);
398        assert(value >= lo[put] && value <= up[put]);
399        // save for later test
400        lo[put++] = value;
401        printf("%d ", value);
402      }
403      printf("\n");
404    }
405    // check valid
406    bool valid = true;
407    // By row
408    for (row = 0; row < size; row++) {
409      put = 0;
410      for (column = 0; column < size; column++)
411        which[put++] = row * size + column;
412      assert(put == size);
413      int i;
414      for (i = 0; i < put; i++)
415        which[i] = (int)lo[which[i]];
416      std::sort(which, which + put);
417      int last = valueOffset - 1;
418      for (i = 0; i < put; i++) {
419        int value = which[i];
420        if (value != last + 1)
421          valid = false;
422        last = value;
423      }
424    }
425    // By column
426    for (column = 0; column < size; column++) {
427      put = 0;
428      for (row = 0; row < size; row++)
429        which[put++] = row * size + column;
430      assert(put == size);
431      int i;
432      for (i = 0; i < put; i++)
433        which[i] = (int)lo[which[i]];
434      std::sort(which, which + put);
435      int last = valueOffset - 1;
436      for (i = 0; i < put; i++) {
437        int value = which[i];
438        if (value != last + 1)
439          valid = false;
440        last = value;
441      }
442    }
443    // By box
444    for (row = 0; row < size; row += block_size) {
445      for (column = 0; column < size; column += block_size) {
446        put = 0;
447        for (int row2 = row; row2 < row + block_size; row2++) {
448          for (int column2 = column; column2 < column + block_size; column2++)
449            which[put++] = row2 * size + column2;
450        }
451        assert(put == size);
452        int i;
453        for (i = 0; i < put; i++)
454          which[i] = (int)lo[which[i]];
455        std::sort(which, which + put);
456        int last = valueOffset - 1;
457        for (i = 0; i < put; i++) {
458          int value = which[i];
459          if (value != last + 1)
460            valid = false;
461          last = value;
462        }
463      }
464    }
465    if (valid) {
466      printf("solution is valid\n");
467    } else {
468      printf("solution is not valid\n");
469      abort();
470    }
471  }
472  return 0;
473}
Note: See TracBrowser for help on using the repository browser.