#ifndef _DETER_PRED_TRANSDUCER2_H_
#define _DETER_PRED_TRANSDUCER2_H_

#include <deque>
#include <iostream>
#include <iterator>
#include <list>
#include <map>
#include <set>
#include <utility>
#include <vector>

#include <boost/function_output_iterator.hpp>
#include <boost/progress.hpp>

#include <outilex/ostream_dumper.h>

// Currently empty; kept so that code naming `detail` keeps compiling.
// (No trailing ';' after the closing brace: an empty declaration at namespace
// scope is ill-formed in C++98 and warned about by -Wpedantic.)
namespace detail {

} // namespace "detail"



/* Lazily determinized view of a predicate transducer.
 *
 * The underlying PredFst is stored in T and preprocessed once by init_fst()
 * so that, for every state of T, the input labels of its outgoing transitions
 * are pairwise equal or disjoint.  A state of the determinized machine is
 * identified by a set of (state of T, delayed output) positions; states are
 * created on demand, and a state's outgoing transitions are computed only the
 * first time trans_begin()/trans_end() is called on it.
 *
 * Since that expansion happens inside const accessors, all bookkeeping members
 * are mutable and are modified without any locking: an instance must not be
 * used concurrently from several threads without external synchronization.
 *
 * SemiRingOps must provide static one(), zero(), plus(a, b), mult(a, b) and
 * minus(a, b) over output_type.
 */
template<typename PredFst, typename SemiRingOps>
class deter_pred_fst2 {

public:

  typedef PredFst pred_fst;
  typedef SemiRingOps semiring_ops;


  typedef typename pred_fst::elem_input_type elem_input_type;
  typedef typename pred_fst::input_type input_type;
  typedef typename pred_fst::output_type output_type;

  typedef std::set<output_type> outputs_type;


  /* A (state of T, delayed output) pair.  Sets of positions (stateid)
   * identify the states of the determinized machine.
   */
  struct position {

    int q;
    output_type delayed;

    inline position(int q_, const output_type & out) : q(q_), delayed(out) {}

    /* lexicographic order on (q, delayed), required to store positions in a std::set */
    inline bool operator<(const position & b) const {
      if (q != b.q) { return q < b.q; }
      return delayed < b.delayed;
    }

    void dump(std::ostream & os) const {
      os << "(" << q << "," << delayed << ")";
    }

    /* found by argument-dependent lookup on position */
    friend std::ostream & operator<<(std::ostream & os, const position & p) {
      p.dump(os); return os;
    }
  };


  typedef std::set<position> stateid;



  /* A transition of the determinized machine: input label, output, target. */
  struct transition {

    int to_;
    input_type in_;
    output_type out_;

    inline transition() : to_(-1), in_(), out_() {}
    inline transition(const input_type & in, const output_type & out, int to)
      : to_(to), in_(in), out_(out) {}

    /* conversion from any transition type exposing to(), in() and out()
     * (used to import the transitions of T)
     */
    template<typename Transition>
    inline transition(const Transition & t)
      : to_(t.to()), in_(t.in()), out_(t.out()) {}

    inline int to() const { return to_; }

    inline input_type & in() { return in_; }
    inline const input_type & in() const  { return in_; }

    inline output_type & out() { return out_; }
    inline const output_type & out() const { return out_; }

    void dump(std::ostream & os) const {
      os << "(" << in_ << ", " << out_ << ", " << to_ << ")";
    }

    /* found by argument-dependent lookup on transition, so `os << tr` works
     * (template parameters of an out-of-class operator would not be deducible)
     */
    friend std::ostream & operator<<(std::ostream & os, const transition & t) {
      t.dump(os); return os;
    }
  };

  typedef std::vector<transition> transitions;
  typedef typename transitions::iterator trans_iterator;
  typedef typename transitions::const_iterator const_trans_iterator;


  /* A state of the determinized machine.  ndeter_trans holds the pending,
   * not yet determinized outgoing transitions; it is consumed by
   * determinize(), which fills trans.
   */
  struct state {

    bool final;
    outputs_type final_outputs;
    std::list<transition> ndeter_trans;
    transitions trans;

    state() : final(false), final_outputs(), ndeter_trans(), trans() {}
  };



public:

  /* default constructor */
  inline deter_pred_fst2() : T(), states(), id2state() {}

  /* construct from an underlying pred_fst (copied) */
  inline deter_pred_fst2(const pred_fst & fst)
    : T(), states(), id2state() { init(fst); }


  /* (re)initialize from a copy of a new non-deterministic fst */
  void init(const pred_fst & fst) {
    clear();
    T = fst;
    setup_from_fst();
  }

  /* like init() but takes fst's content by swapping, avoiding a costly copy;
   * afterwards fst holds the former (cleared) content of T.
   */
  void eat(pred_fst & fst) {
    clear();
    T.swap(fst);
    setup_from_fst();
  }


  inline void clear() { T.clear(); states.clear(); id2state.clear(); }

  inline int start() const { return 0; }

  inline bool final(int q) const { return states[q].final; }

  inline const outputs_type & final_outputs(int q) const { return states[q].final_outputs; }
  inline outputs_type & final_outputs(int q) { return states[q].final_outputs; }


  /* Both accessors determinize q on first use, which may create new states.
   * Iterators returned for q remain valid when other states are created later.
   */
  inline const_trans_iterator trans_begin(int q) const { determinize(q); return states[q].trans.begin(); }
  inline const_trans_iterator trans_end(int q) const { determinize(q); return states[q].trans.end(); }


protected:

  typedef std::map<stateid, int> state_map;


  mutable pred_fst T;

  /* std::deque rather than std::vector: add_state() appends to this container
   * from inside trans_begin()/trans_end(), possibly while a caller still holds
   * iterators into another state's `trans` vector.  Appending to a deque never
   * relocates existing elements, so those iterators stay valid.
   */
  mutable std::deque<state> states;
  mutable state_map id2state;


  /* shared tail of init() and eat(): preprocess T, then create the start
   * state {(T.start(), one)}, which gets index 0 (see start()).
   */
  void setup_from_fst() {
    init_fst();
    stateid id;
    id.insert(position(T.start(), semiring_ops::one()));
    add_state(id);
  }


  /* Refine the transitions of every state of T so that their input labels
   * are pairwise equal or disjoint (see deter_trans_list), reporting
   * progress and elapsed time on std::cerr.
   */
  void init_fst() {
    std::cerr << "init_fst size = " << T.size() << "\n";
    boost::timer tmr;
    boost::progress_display progress(T.size(), std::cerr);
    for (int q = 0; q < T.size(); ++q) {
      std::list<transition> trans_list(T.trans_begin(q), T.trans_end(q));
      deter_trans_list(trans_list);
      T.trans(q).clear();
      T.trans(q).insert(T.trans_end(q), trans_list.begin(), trans_list.end());
      ++progress;
    }
    std::cerr << "out of init_fst, " << tmr.elapsed() << "s.\n";
  }


  /* Fill the freshly created states[q] from its identifying position set id.
   *
   * For each position p in id, with d = p.delayed.delay():
   * - if T.final(p.q), q becomes final and gets mult(d, o) for every final
   *   output o of p.q (or d itself, with a warning, if p.q has none);
   * - every outgoing transition of p.q is added to states[q].ndeter_trans,
   *   its output replaced by mult(d, out), after splitting input labels so
   *   that any two labels in ndeter_trans are equal or disjoint.
   */
  void init_state(int q, const stateid & id) const {

    state & st = states[q];
    st.final = false;

    std::list<transition> & ndeter_trans = st.ndeter_trans;

    for (typename stateid::const_iterator it = id.begin(); it != id.end(); ++it) {

      output_type delayed = it->delayed.delay();

      if (T.final(it->q)) {
        st.final = true;
        if (! T.final_outputs(it->q).empty()) {
          for (typename pred_fst::outputs_type::const_iterator it2 = T.final_outputs(it->q).begin();
               it2 != T.final_outputs(it->q).end(); ++it2) {
            st.final_outputs.insert(semiring_ops::mult(delayed, *it2));
          }
        } else {
          std::cerr << "warning final state with no final outputs\n";
          st.final_outputs.insert(delayed);
        }
      }

      std::list<transition> trans(T.trans_begin(it->q), T.trans_end(it->q));

      for (typename std::list<transition>::iterator tr1 = trans.begin(); tr1 != trans.end(); ++tr1) {

        for (typename std::list<transition>::iterator tr2 = ndeter_trans.begin();
             tr2 != ndeter_trans.end(); ++tr2) {

          if (tr1->in() == tr2->in()) {
            // labels in ndeter_trans are pairwise equal or disjoint, so tr1's
            // label is now equal or disjoint with all of them: stop here
            break;
          }

          input_type i = tr1->in();
          i &= tr2->in();

          if (! i) { // disjoint labels: try the next pending transition
            continue;
          }

          // tr1's label minus i goes to the end of trans, to be processed later
          input_type::minus(tr1->in(), i,
                            boost::make_function_output_iterator(make_trans_outputer(*tr1,
                                                                                     std::back_inserter(trans))));
          // tr2's label minus i is inserted before tr2: it is disjoint from i,
          // so it need not be compared with tr1 again
          input_type::minus(tr2->in(), i,
                            boost::make_function_output_iterator(make_trans_outputer(*tr2,
                                                                                     std::inserter(ndeter_trans, tr2))));
          tr1->in() = i;
          tr2->in() = i;
        }

        // prepend the position's delayed output to the transition output
        tr1->out() = semiring_ops::mult(delayed, tr1->out());
      }

      ndeter_trans.splice(ndeter_trans.begin(), trans);
    }
  }


  /* Return the index of the state identified by id, creating and
   * initializing it if it does not exist yet.
   */
  int add_state(const stateid & id) const {

    const int res = static_cast<int>(states.size());

    // single lookup: inserts (id, res) only if id is not already known
    std::pair<typename state_map::iterator, bool> ins =
      id2state.insert(typename state_map::value_type(id, res));
    if (! ins.second) { return ins.first->second; }

    states.push_back(state());
    init_state(res, id);
    return res;
  }


  /* Turn the pending transitions of q into deterministic ones (no-op once
   * the pending list is empty): one transition per distinct input label,
   * carrying the common output prefix of the target positions.
   */
  void determinize(int q) const {

    if (states[q].ndeter_trans.empty()) { return; }

    // move the pending list out of the state before creating new states
    std::list<transition> ndeter_trans;
    ndeter_trans.swap(states[q].ndeter_trans);

    while (! ndeter_trans.empty()) {
      input_type label = ndeter_trans.front().in(); // copy: NEXT erases the front
      stateid idto;
      NEXT(label, ndeter_trans, idto);
      output_type prefix;
      LCP(idto, prefix);
      int to = add_state(idto);
      states[q].trans.push_back(transition(label, prefix, to));
    }
  }



  /* Add to res the positions (target, output) of every transition of trans
   * whose input equals label, and remove those transitions from trans.
   */
  static void NEXT(const input_type & label, std::list<transition> & trans, stateid & res) {

    typename std::list<transition>::iterator tr = trans.begin();

    while (tr != trans.end()) {
      if (tr->in() == label) {
        res.insert(position(tr->to(), tr->out()));
        tr = trans.erase(tr);
      } else { ++tr; }
    }
  }


  /* Set prefix to the semiring sum (plus, starting from zero) of the delayed
   * outputs in id, and replace each delayed output d by minus(d, prefix).
   */
  static void LCP(stateid & id, output_type & prefix) {
    prefix = semiring_ops::zero();
    for (typename stateid::iterator it = id.begin(); it != id.end(); ++it) {
      prefix = semiring_ops::plus(prefix, it->delayed);
    }
    stateid nid;
    for (typename stateid::iterator it = id.begin(); it != id.end(); ++it) {
      nid.insert(nid.end(), position(it->q, semiring_ops::minus(it->delayed, prefix)));
    }
    id.swap(nid);
  }

protected:

  /* helper functions and classes */

  /* Split the input labels of trans so that any two of them are equal or
   * disjoint; the split-off parts are appended to trans as copies of the
   * transition they come from.
   */
  static void deter_trans_list(std::list<transition> & trans) {

    typename std::list<transition>::iterator t1, t2;

    for (t1 = trans.begin(); t1 != trans.end(); ++t1) {

      t2 = t1;

      for (++t2; t2 != trans.end(); ++t2) {

        if (t1->in() == t2->in()) { continue; } // equal sets

        input_type i = t1->in();
        i &= t2->in();

        if (! i) { continue; } // disjoint sets

        input_type::minus(t1->in(), i,
                          boost::make_function_output_iterator(make_trans_outputer(*t1,
                                                                                   std::back_inserter(trans))));
        input_type::minus(t2->in(), i,
                          boost::make_function_output_iterator(make_trans_outputer(*t2,
                                                                                   std::back_inserter(trans))));

        t1->in() = i;
        t2->in() = i;
      }
    }
  }


  /* Function object writing, for each input label m it receives, a copy of
   * the stored transition relabelled with m to the output iterator.
   */
  template<typename Transition, typename OutputIterator>
  struct trans_outputer {

    trans_outputer(const Transition & t, OutputIterator o)
      : tr(t), out(o) {}

    void operator()(const input_type & m) {
      tr.in() = m;
      *out = tr; ++out;
    }

    Transition tr;
    OutputIterator out;
  };

  template<typename Transition, typename OutputIterator>
  static inline trans_outputer<Transition, OutputIterator>
  make_trans_outputer(const Transition & t, OutputIterator o) {
    return trans_outputer<Transition, OutputIterator>(t, o);
  }

};


/* Write a deter_pred_fst2<PredFst, Ops>::transition to os via its dump()
 * method and return os.
 *
 * NOTE(review): PredFst and Ops only appear inside a nested type name, which
 * is a non-deduced context, so an ordinary `os << tr` expression can never
 * select this template; it is only reachable with explicit template
 * arguments, e.g. `operator<< <F, O>(os, tr)`.
 */
template<typename PredFst, typename Ops>
std::ostream &
operator<<(std::ostream & os,
           const typename deter_pred_fst2<PredFst, Ops>::transition & tr)
{
  tr.dump(os);
  return os;
}
#endif

