Changeset 2983


Ignore:
Timestamp:
Oct 21, 2013 3:47:31 AM (6 years ago)
Author:
bradbell
Message:

First version that properly conditionally skips an entire atomic function.

Location:
branches/opt_cond_exp
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • branches/opt_cond_exp/cppad/local/num_skip.hpp

    r2978 r2983  
    9595                        play_.forward_cskip(op, arg, i_op, i_var);
    9696                //
    97                 if( cskip_op_[i_op] & (NumRes(op) > 0) )
    98                         n_skip++;
     97                if( cskip_op_[i_op] )
     98                        n_skip += NumRes(op);
    9999        }
    100100        return n_skip;
  • branches/opt_cond_exp/cppad/local/optimize.hpp

    r2982 r2983  
    256256        /// index for right comparison operand
    257257        size_t right;
     258        /// set of variables to skip on true
     259        CppAD::vector<size_t> skip_var_true;
     260        /// set of variables to skip on false
     261        CppAD::vector<size_t> skip_var_false;
    258262        /// set of operations to skip on true
    259         CppAD::vector<size_t> skip_on_true;
     263        CppAD::vector<size_t> skip_op_true;
    260264        /// set of operations to skip on false
    261         CppAD::vector<size_t> skip_on_false;
     265        CppAD::vector<size_t> skip_op_false;
     266        /// size of skip_op_true
     267        size_t n_op_true;
     268        /// size of skip_op_false
     269        size_t n_op_false;
    262270        /// index in the argument recording of first argument for this CSkipOp
    263271        size_t i_arg;
     
    275283        /// index of the beginning of the atomic call sequence; i.e.,
    276284        /// the first UserOp.
    277         size_t new_op_begin;
     285        size_t op_begin;
    278286        /// If this is a conditional connection, this is one more than the
    279287        ///  operator index of the ending of the atomic call sequence; i.e.,
    280288        /// the second UserOp.
    281         size_t new_op_end;
     289        size_t op_end;
    282290};
    283291
     
    15771585
    15781586                        // Operations where there is noting to do
    1579                         case BeginOp:
    15801587                        case ComOp:
    15811588                        case EndOp:
    1582                         case InvOp:
    15831589                        case ParOp:
    15841590                        case PriOp:
    15851591                        break;  // --------------------------------------------
     1592
     1593                        // Operators that never get removed
     1594                        case BeginOp:
     1595                        case InvOp:
     1596                        tape[i_var].connect_type = yes_connected;
     1597                        break;
    15861598
    15871599                        // Load using a parameter index
     
    16401652                                optimize_user_info info;
    16411653                                info.connect_type = not_connected;
     1654                                info.op_end       = i_op + 1;
    16421655                                user_info.push_back(info);
     1656                               
    16431657                        }
    16441658                        else
     
    16511665                                //
    16521666                                CPPAD_ASSERT_UNKNOWN( user_curr + 1 == user_info.size() );
    1653                                 user_curr         = user_info.size();
     1667                                user_info[user_curr].op_begin = i_op;
     1668                                user_curr                     = user_info.size();
    16541669               }
    16551670                        break;
     
    16751690                        --user_j;
    16761691                        if( ! user_s[user_j].empty() )
    1677                         {       tape[arg[0]].connect_type          = yes_connected;
    1678                                 user_info[user_curr].connect_type  = yes_connected;
    1679                         }
     1692                                tape[arg[0]].connect_type =
     1693                                        user_info[user_curr].connect_type;
    16801694                        if( user_j == 0 )
    16811695                                user_state = user_start;
     
    17591773
    17601774        // Determine which variables can be conditionally skipped
    1761         // 2DO: Perhaps we should change NumRes( UserOp ) = 1 , so it
    1762         // also gets a separate skip_on_true and skip_on false value
    17631775        for(i = 0; i < num_var; i++)
    17641776        {       if( tape[i].connect_type == cexp_true_connected )
    17651777                {       j = tape[i].connect_index;
    1766                         cskip_info[j].skip_on_false.push_back(i);
     1778                        cskip_info[j].skip_var_false.push_back(i);
    17671779                }
    17681780                if( tape[i].connect_type == cexp_false_connected )
    17691781                {       j = tape[i].connect_index;
    1770                         cskip_info[j].skip_on_true.push_back(i);
     1782                        cskip_info[j].skip_var_true.push_back(i);
    17711783                }
    17721784        }
     1785        // Move skip information from user_info to cskip_info
     1786        for(i = 0; i < user_info.size(); i++)
     1787        {       if( user_info[i].connect_type == cexp_true_connected )
     1788                {       j = user_info[i].connect_index;
     1789                        cskip_info[j].n_op_false =
     1790                                user_info[i].op_end - user_info[i].op_begin;
     1791                }
     1792                if( user_info[i].connect_type == cexp_false_connected )
     1793                {       j = user_info[i].connect_index;
     1794                        cskip_info[j].n_op_true =
     1795                                user_info[i].op_end - user_info[i].op_begin;
     1796                }
     1797        }
     1798
    17731799        // Sort the conditional skip information by the maximum of the
    17741800        // index for the left and right comparision operands
     
    18581884                if( skip )
    18591885                {       j     = cskip_info_order[cskip_info_next];
    1860                         skip &= cskip_info[j].left < i_var;
    1861                         skip &= cskip_info[j].right < i_var;
     1886                        if( NumRes(op) > 0 )
     1887                        {       skip &= cskip_info[j].left < i_var;
     1888                                skip &= cskip_info[j].right < i_var;
     1889                        }
     1890                        else
     1891                        {       skip &= cskip_info[j].left <= i_var;
     1892                                skip &= cskip_info[j].right <= i_var;
     1893                        }
    18621894                }
    18631895                if( skip )
    18641896                {       cskip_info_next++;
    1865                         skip &= cskip_info[j].skip_on_true.size() > 0 ||
    1866                                         cskip_info[j].skip_on_false.size() > 0;
     1897                        skip &= cskip_info[j].skip_var_true.size() > 0 ||
     1898                                        cskip_info[j].skip_var_false.size() > 0;
    18671899                        if( skip )
    18681900                        {       optimize_cskip_info info = cskip_info[j];
    18691901                                CPPAD_ASSERT_UNKNOWN( NumRes(CSkipOp) == 0 );
    1870                                 CPPAD_ASSERT_UNKNOWN( info.left < i_var );
    1871                                 CPPAD_ASSERT_UNKNOWN( info.right < i_var );
    1872                                 size_t n_true  = info.skip_on_true.size();
    1873                                 size_t n_false = info.skip_on_false.size();
     1902                                size_t n_true  =
     1903                                        info.skip_var_true.size() + info.n_op_true;
     1904                                size_t n_false =
     1905                                        info.skip_var_false.size() + info.n_op_false;
    18741906                                size_t n_arg   = 7 + n_true + n_false;
    18751907                                // reserve space for the arguments to this operator but
     
    22702302                                CPPAD_ASSERT_UNKNOWN( user_curr > 0 );
    22712303                                user_curr--;
    2272                                 user_info[user_curr].new_op_begin = i_op;
     2304                                user_info[user_curr].op_begin = rec->num_rec_op();
    22732305                        }
    22742306                        else
    22752307                        {       user_state = user_start;
    2276                                 user_info[user_curr].new_op_end = i_op;
     2308                                user_info[user_curr].op_end = rec->num_rec_op() + 1;
    22772309                        }
    22782310                        // user_index, user_id, user_n, user_m
     
    23532385# endif
    23542386
     2387        // Move skip information from user_info to cskip_info
     2388        for(i = 0; i < user_info.size(); i++)
     2389        {       if( user_info[i].connect_type == cexp_true_connected )
     2390                {       j = user_info[i].connect_index;
     2391                        k = user_info[i].op_begin;
     2392                        while(k < user_info[i].op_end)
     2393                                cskip_info[j].skip_op_false.push_back(k++);
     2394                }
     2395                if( user_info[i].connect_type == cexp_false_connected )
     2396                {       j = user_info[i].connect_index;
     2397                        k = user_info[i].op_begin;
     2398                        while(k < user_info[i].op_end)
     2399                                cskip_info[j].skip_op_true.push_back(k++);
     2400                }
     2401        }
     2402
    23552403        // fill in the arguments for the CSkip operations
    23562404        CPPAD_ASSERT_UNKNOWN( cskip_info_next == cskip_info.size() );
     
    23582406        {       optimize_cskip_info info = cskip_info[i];
    23592407                if( info.i_arg > 0 )
    2360                 {       size_t n_true  = info.skip_on_true.size();
    2361                         size_t n_false = info.skip_on_false.size();
     2408                {       CPPAD_ASSERT_UNKNOWN( info.n_op_true==info.skip_op_true.size() );
     2409                        CPPAD_ASSERT_UNKNOWN(info.n_op_false==info.skip_op_false.size());
     2410                        size_t n_true  =
     2411                                info.skip_var_true.size() + info.skip_op_true.size();
     2412                        size_t n_false =
     2413                                info.skip_var_false.size() + info.skip_op_false.size();
    23622414                        size_t i_arg   = info.i_arg;
    23632415                        rec->ReplaceArg(i_arg++, info.cop   );
     
    23672419                        rec->ReplaceArg(i_arg++, n_true     );
    23682420                        rec->ReplaceArg(i_arg++, n_false    );
    2369                         for(j = 0; j < n_true; j++)
    2370                         {       i_var = info.skip_on_true[j];
     2421                        for(j = 0; j < info.skip_var_true.size(); j++)
     2422                        {       i_var = info.skip_var_true[j];
    23712423                                CPPAD_ASSERT_UNKNOWN( tape[i_var].new_op > 0 );
    23722424                                rec->ReplaceArg(i_arg++, tape[i_var].new_op );
    23732425                        }
    2374                         for(j = 0; j < n_false; j++)
    2375                         {       i_var = info.skip_on_false[j];
     2426                        for(j = 0; j < info.skip_op_true.size(); j++)
     2427                        {       i_op = info.skip_op_true[j];
     2428                                rec->ReplaceArg(i_arg++, i_op);
     2429                        }
     2430                        for(j = 0; j < info.skip_var_false.size(); j++)
     2431                        {       i_var = info.skip_var_false[j];
    23762432                                CPPAD_ASSERT_UNKNOWN( tape[i_var].new_op > 0 );
    23772433                                rec->ReplaceArg(i_arg++, tape[i_var].new_op );
     2434                        }
     2435                        for(j = 0; j < info.skip_op_false.size(); j++)
     2436                        {       i_op = info.skip_op_false[j];
     2437                                rec->ReplaceArg(i_arg++, i_op);
    23782438                        }
    23792439                        rec->ReplaceArg(i_arg++, n_true + n_false);
  • branches/opt_cond_exp/test_more/optimize.cpp

    r2979 r2983  
    1717
    1818namespace {
     19        // -------------------------------------------------------------------
     20        // Test conditional optimizing out call to an atomic function
     21        void k_algo(
     22                const CppAD::vector< CppAD::AD<double> >& x ,
     23                      CppAD::vector< CppAD::AD<double> >& y )
     24        {       y[0] = x[0] + x[1]; }
     25
     26        void h_algo(
     27                const CppAD::vector< CppAD::AD<double> >& x ,
     28                      CppAD::vector< CppAD::AD<double> >& y )
     29        {       y[0] = x[0] - x[1]; }
     30
     31        bool atomic_cond_exp(void)
     32        {       bool ok = true;
     33                typedef CppAD::vector< CppAD::AD<double> > ADVector;
     34
     35                // Create a checkpoint version of the function g
     36                ADVector ax(2), ag(1), ah(1), ay(1);
     37                ax[0] = 0.;
     38                ax[1] = 1.;
     39                CppAD::checkpoint<double> k_check("k_check", k_algo, ax, ag);
     40                CppAD::checkpoint<double> h_check("h_check", h_algo, ax, ah);
     41
     42                // independent variable vector
     43                Independent(ax);
     44
     45                // atomic function calls that get conditionally used
     46                k_check(ax, ag);
     47                h_check(ax, ah);
     48
     49                // conditional expression
     50                ay[0] = CondExpLt(ax[0], ax[1], ag[0], ah[0]);
     51       
     52                // create function object f : ax -> ay
     53                CppAD::ADFun<double> f;
     54                f.Dependent(ax, ay);
     55       
     56                // use zero order to evaluate f(3,4)
     57                CppAD::vector<double>  x( f.Domain() );
     58                CppAD::vector<double>  y( f.Range() );
     59                x[0] = 3.;
     60                x[1] = 4.;
     61                y    = f.Forward(0, x);
     62                ok  &= y[0] == x[0] + x[1];
     63
     64                // before optimize
     65                ok  &= f.number_skip() == 0;
     66
     67                // now optimize the operation sequence
     68                f.optimize();
     69
     70                // optimized zero order forward
     71                x[0] = 4.;
     72                x[1] = 3.;
     73                y    = f.Forward(0, x);
     74                ok   = y[0] == x[0] - x[1];
     75
     76                // after optimize can skip either call to g or call to h
     77                ok  &= f.number_skip() == 1;
     78       
     79                return ok;
     80        }
    1981        // -------------------------------------------------------------------
    2082        // Test of optimizing out arguments to an atomic function
     
    12101272bool optimize(void)
    12111273{       bool ok = true;
     1274        // check optimizing out entire atomic function
     1275        ok     &= atomic_cond_exp();
    12121276        // check optimizing out atomic arguments
    12131277        ok     &= atomic_arguments();
Note: See TracChangeset for help on using the changeset viewer.