#include <sstream>
#include <boost/lexical_cast.hpp>

#include <outilex/xml.h>
#include <outilex/syntagm_fst.h>


using namespace std;
using namespace boost;

/* Unary predicate: does a fixed lexical entry match a mask?
 * The mask is given either directly or as the input label of a
 * syntagm_fst::transition.  The predicate keeps a reference to the entry,
 * so the entry must outlive it.
 */
struct match_entry {

  const lexical_entry & e;

  match_entry(const lexical_entry & entry)
    : e(entry)
  {}

  /* test a bare mask against the entry */
  bool operator()(const lexical_mask & m) {
    bool matched = lexical_mask::match(e, m);
    return matched;
  }

  /* test the input label of a transition against the entry */
  bool operator()(const syntagm_fst::transition & trans) {
    bool matched = lexical_mask::match(e, trans.in());
    return matched;
  }
};


/* Move out of ntrans every transition whose input label matches e.
 *
 * Transitions of ntrans are visited in order:
 *  - when e matches tr.in(): m is combined with tr.in() (`m &= tr.in()`,
 *    the intersection per the comments below); the part of tr.in() outside
 *    the new m is handed to ndeter_trans_inserter(ntrans, itr, tr) so it
 *    stays in ntrans; tr, relabelled with m, is written to `out` and then
 *    erased from ntrans; for every delayed output d, prefix becomes
 *    output_type::plus(prefix, output_type::mult(d, tr.out())).
 *  - otherwise: m is replaced by the first piece of (m minus tr.in()) that
 *    still matches e, so m ends up disjoint from tr.in() (asserted).
 *
 * @param e        lexical entry to match
 * @param ntrans   non-determinized transitions; modified in place
 * @param out      receives the matching transitions, relabelled
 * @param m        in/out: a mask matching e, narrowed at every step
 * @param delayeds delayed outputs attached to the source position
 * @param prefix   in/out: accumulated through output_type::plus; presumably
 *                 the longest common output prefix of the matching
 *                 transitions (TODO confirm output_type semantics)
 */
template<typename OutputIterator>
void syntagm_fst::extract_matching_trans(const lexical_entry & e,
                                         ndeter_transitions & ntrans,
                                         OutputIterator out,
                                         lexical_mask & m,
                                         const std::set<output_type> & delayeds,
                                         output_type & prefix) {

  //cerr << "extracting_trans for " << e << "(prefix=" << prefix <<")\n";

  ndeter_transitions::iterator itr;

  // itr is advanced by hand: erase() yields the following element on a
  // match, itr++ steps over a non-matching transition
  for (itr = ntrans.begin(); itr != ntrans.end();) {

    //cerr << "tr.in = " << itr->in() << "trout= " << itr->out() << endl;
    //cerr << "m=" << m << endl;

    ndeter_transition & tr = *itr;

    if (lexical_mask::match(e, tr.in())) { // curr transition matches with e

//      cerr << "extract matches: " << tr.in() << "MATCHES" << e << "!\n";

      // split the input label: in = (in inter m) u (in minus m)

      m &= tr.in();
 
      // put (in minus m) back into the ndeter_trans
      // (presumably inserted before itr, so this loop does not revisit
      // those pieces; TODO confirm ndeter_trans_inserter semantics)
      lexical_mask::minus(tr.in(), m, ndeter_trans_inserter(ntrans, itr, tr));

      // send (in inter m) to out
      tr.in() = m;

      *out = tr;
      out++;

      /* compute longest common prefix of matching trans */


      for (std::set<output_type>::const_iterator it = delayeds.begin(); it != delayeds.end(); ++it) {
        prefix = output_type::plus(prefix, output_type::mult(*it, tr.out()));
      }

      // remove (in inter m) from the ndeter trans

      itr = ntrans.erase(itr);

    } else { // tr doesn't match

      // compute (m minus tr.in)

      std::vector<lexical_mask> ms;
      lexical_mask::minus(m, tr.in(), back_inserter(ms));

      // reduce mask m so that it is disjoint with tr.in
      // (keeping the piece that still matches e)

      vector<lexical_mask>::iterator it = find_if(ms.begin(), ms.end(), match_entry(e));
      assert(it != ms.end());
      m = *it;

      itr++;
    }
  }
}



/* Return an iterator to the determinized transition leaving state `from`
 * whose input label matches e, building it on demand.
 *
 *  1. If an existing determinized transition of `from` matches e, return it.
 *  2. Otherwise, for every position info of the state, pull the matching
 *     non-determinized transitions out with extract_matching_trans.  This
 *     narrows `label` (starting from a default lexical_mask) and
 *     accumulates `prefix` (starting from output_type::zero()).
 *  3. If nothing matched, cut `label` until it is disjoint from every
 *     determinized transition of the state while still matching e, and add
 *     the transition (label, output_type::one(), -1).
 *  4. Otherwise, re-inject into the non-determinized transitions the part
 *     of each extracted transition's input not covered by `label`, build the
 *     target state id from the (target, delayed output) pairs, and add the
 *     transition (label, prefix, add_state(target id)).
 */
syntagm_fst::trans_iterator syntagm_fst::find_matching_trans(int from, const lexical_entry & e) {

  // NOTE: Q refers into `states`; add_state() at the end may grow that
  // vector, so Q must not be used once add_state() has been called.
  state & Q = states[from];


  /* first lookup in determinized transitions */

  {
    trans_iterator tr = find_if(trans_begin(from), trans_end(from), match_entry(e));
    if (tr != trans_end(from)) {
      return tr;
    }
  }

  /* extract matching trans,
   * and compute label, the biggest lexical_mask matching e and disjoint
   * from all labels of non-matching transitions
   */

  typedef std::vector<std::pair<int, ndeter_transitions> > trans_vec;
  trans_vec matching_trans;

  lexical_mask label; // == LEXIC
  output_type prefix = output_type::zero();

  for (std::map<int, state::info>::iterator it = Q.infos.begin(); it != Q.infos.end(); ++it) {

    int q = (*it).first;
    ndeter_transitions & ntrans = (*it).second.ndeter_trans;
    const std::set<output_type> & delayeds = (*it).second.delayeds;

    int no = matching_trans.size();
    matching_trans.resize(no + 1);

    matching_trans[no].first = q;
    ndeter_transitions & mtrans = matching_trans[no].second;

    extract_matching_trans(e, ntrans, front_inserter(mtrans), label, delayeds, prefix);

    if (mtrans.empty()) { // no matching trans leaving from q
      matching_trans.pop_back();
    }
  }


  // label == label of the determinized transition
  // prefix == longest common prefix


  if (matching_trans.empty()) { // no matching transition at all

    /* label is the lexical mask matching e and disjoint from
     * all masks in ndeter trans.
     * we cut label so that it becomes disjoint from all masks in deter trans
     * (and still matches e), then we add a trans to deter trans, labelled
     * with label and pointing to state -1
     */

    for (trans_iterator tr = Q.trans.begin(); tr != Q.trans.end(); ++tr) { // for each trans in deter trans

      vector<lexical_mask> ms;
      lexical_mask::minus(label, tr->in(), back_inserter(ms));

      // find the label matching with e
      vector<lexical_mask>::iterator it = find_if(ms.begin(), ms.end(), match_entry(e));
      assert(it != ms.end());
      label = *it;
    }

    return Q.trans.insert(Q.trans.end(), transition(label, output_type::one(), -1));
  }

  /*
   *  dirty stuffs :
   *  * reinject subset of matching trans in ndeter trans
   *  * compute stateid on the fly
   */

  lexical_mask m = label;

  // NOTE: `min` is deliberately never cleared.  By the identity below, after
  // handling a transition it holds (tr.in minus label), which is exactly
  // (m minus label) for the next transition since m becomes tr.in.
  vector<lexical_mask> min;

  stateid to_id;


  for (int no = matching_trans.size() - 1; no >= 0; --no) {

    int q = matching_trans[no].first;
    ndeter_transitions & ntrans = Q.infos[q].ndeter_trans;
    ndeter_transitions & mtrans = matching_trans[no].second;

    const set<output_type> & delayeds = Q.infos[q].delayeds;

    for (ndeter_transitions::iterator itr = mtrans.begin(); itr  != mtrans.end(); ++itr) {

      ndeter_transition & tr = *itr;

      assert(in(m, tr.in())); // m should be subset of tr.in) or m == in

      /* reinject (tr.in minus label) into ndeter trans
       * (tr.in minus label) == (tr.in minus m) union (m minus label)
       * with (m minus label) == min
       */

      lexical_mask::minus(tr.in(), m, back_inserter(min));
      copy(min.begin(), min.end(), ndeter_trans_inserter(ntrans, ntrans.begin(), tr));

      m = tr.in();


      /* compute stateid */

      for (set<output_type>::const_iterator it = delayeds.begin(); it != delayeds.end(); ++it) {
        output_type mydelay = output_type::minus(output_type::mult(*it, tr.out()), prefix);
        to_id.insert(position(tr.to(), mydelay));
      }
    }
  }


  /* adding new determinist transition */

  int to = add_state(to_id);
  return add_trans(from, label, prefix, to);
}


/* Append the determinized transition (in, out, to) to the transitions of
 * state `from` and return an iterator to the newly inserted transition.
 */
syntagm_fst::trans_iterator syntagm_fst::add_trans(int from, const input_type & in,
                                                   const output_type & out, int to) {

  state & src = states[from];
  const transition created(in, out, to);
  return src.trans.insert(src.trans.end(), created);
}


void syntagm_fst::init_state(int q, const stateid & id) {

  // WARNING: don't use reference, the address of states[q] is susceptible to change
  // during the function execution, since the vector of states can grow
  // state & Q = states[q];
  // assert(& states[q] == & Q);

  states[q].final = false;

  //assert(Q.infos.empty());
  //assert(Q.synt_trans.empty());

  for (std::set<position>::const_iterator it = id.begin(); it != id.end(); ++it) {

    state::info & inf = states[q].infos[it->q];

    output_type delayed = it->delayed.delay(); // increment delayed output pos

    inf.delayeds.insert(delayed); // increment delayed by one

    if (inf.ndeter_trans.empty()) {
      inf.ndeter_trans.insert(inf.ndeter_trans.begin(), A.trans_begin(it->q), A.trans_end(it->q)); 
    }

#warning syntagm_fst::init_state: non determinize synt transitions

//    cerr << "init_state adding synt_trans:\n";

    for (syntagm_pattern::const_synt_trans_iterator str = A.synt_trans_begin(it->q); str != A.synt_trans_end(it->q); ++str) {

//      cerr << "trans in=" << str->in() << endl;

#warning not so sure ?

      stateid id;
      id.insert(position(str->to(), output_type::one()));

      int to = add_state(id);
      //assert(& states[q] == & Q);
      states[q].synt_trans.push_back(synt_transition(str->in(),
                                                     output_type::mult(delayed, str->out()), to));
    }

    if (A.final(it->q)) {
      states[q].final = true;
      states[q].final_outputs.insert(delayed);
    }
  }
}


int syntagm_fst::add_state(const stateid & id) {

  state_map::iterator it = id2state.find(id);
  if (it != id2state.end()) { return (*it).second; }

  int res = states.size();
  states.resize(res + 1);
  id2state[id] = res;

  init_state(res, id);
  return res;
}

/* Reset the transducer: forget its name, pattern, states and id -> state map.
 * (lingdef is left untouched.)
 */
void syntagm_fst::clear() {
  name.clear();
  A.clear();
  id2state.clear();
  states.clear();
}


void syntagm_fst::init(const syntagm_pattern & pattern) {

  clear();
  A = pattern;

  name = A.get_name();

  stateid id;
  id.insert(position(A.start(), output_type::one()));
  add_state(id);
}



/* XML Serialisation */

/* Serialize as <position q="..."> wrapping the delayed output's own XML. */
void syntagm_fst::position::write_XML(xmlwriter & writer) const {
  const string qstr = lexical_cast<string>(q);
  writer.start_element("position");
  writer.write_attribute("q", qstr);
  delayed.write_XML(writer);
  writer.end_element();
}

/* Read a <position> element: attribute "q" into q (left unchanged when the
 * attribute is absent) and the first child element named
 * output_type::xml_name() into delayed.
 */
void syntagm_fst::position::read_XML(xmlNodePtr node) {

  // xmlGetProp returns NULL for a missing attribute; building an
  // istringstream from a NULL char* is undefined behaviour
  char * text = (char *) xmlGetProp(node, (const xmlChar *) "q");
  if (text) {
    istringstream(text) >> q;
    xmlFree(text);
  }

  // advance to the next sibling on every iteration, otherwise a leading
  // non-output child would make this loop spin forever
  for (node = node->xmlChildrenNode; node; node = node->next) {
    if (xmlStrcmp(node->name, output_type::xml_name()) == 0) {
      delayed.read_XML(node);
      return;
    }
  }
}


/* Serialize as
 *   <transition to="N"><in>INPUT</in><out>OUTPUT</out></transition>
 */
void syntagm_fst::transition::write_XML(xmlwriter & writer) const {

  const string target = lexical_cast<string>(to_);

  writer.start_element("transition");
  writer.write_attribute("to", target);

  writer.start_element("in");   // input label
  in_.write_XML(writer);
  writer.end_element();         // </in>

  writer.start_element("out");  // output
  out_.write_XML(writer);
  writer.end_element();         // </out>

  writer.end_element();         // </transition>
}

/* Read a <transition> element.
 *  - attribute "to" -> to_ (left unchanged when the attribute is absent);
 *  - every input_type::xml_name() child of <in> is passed to in_.read_XML;
 *  - every output_type::xml_name() child of <out> is passed to out_.read_XML.
 */
void syntagm_fst::transition::read_XML(xmlNodePtr node, ling_def * lingdef) {

  // xmlGetProp returns NULL for a missing attribute; istringstream(NULL)
  // would be undefined behaviour
  char * text = (char *) xmlGetProp(node, (const xmlChar *) "to");
  if (text) {
    istringstream(text) >> to_;
    xmlFree(text);
  }

  for (node = node->xmlChildrenNode; node; node = node->next) {

    if (xmlStrcmp(node->name, (const xmlChar *) "in") == 0) {

      for (xmlNodePtr cur = node->xmlChildrenNode; cur; cur = cur->next) {
        if (xmlStrcmp(cur->name, input_type::xml_name()) == 0) {
          in_.read_XML(cur, lingdef);
        }
      }

    } else if (xmlStrcmp(node->name, (const xmlChar *) "out") == 0) {

      for (xmlNodePtr cur = node->xmlChildrenNode; cur; cur = cur->next) {
        if (xmlStrcmp(cur->name, output_type::xml_name()) == 0) {
          out_.read_XML(cur);
        }
      }
    }
  }
}


/* Serialize as
 *   <synt_transition to="N"><in>INPUT</in><out>OUTPUT</out></synt_transition>
 */
void syntagm_fst::synt_transition::write_XML(xmlwriter & writer) const {

  const string target = lexical_cast<string>(to_);

  writer.start_element("synt_transition");
  writer.write_attribute("to", target);

  writer.start_element("in");   // syntagm input label
  in_.write_XML(writer);
  writer.end_element();         // </in>

  writer.start_element("out");  // output
  out_.write_XML(writer);
  writer.end_element();         // </out>

  writer.end_element();         // </synt_transition>
}


/* Read a <synt_transition> element.
 *  - attribute "to" -> to_ (left unchanged when the attribute is absent);
 *  - every synt_input_type::xml_name() child of <in> is passed to in_.read_XML;
 *  - every output_type::xml_name() child of <out> is passed to out_.read_XML.
 */
void syntagm_fst::synt_transition::read_XML(xmlNodePtr node, ling_def * lingdef) {

  // xmlGetProp returns NULL for a missing attribute; istringstream(NULL)
  // would be undefined behaviour
  char * text = (char *) xmlGetProp(node, (const xmlChar *) "to");
  if (text) {
    istringstream(text) >> to_;
    xmlFree(text);
  }

  for (node = node->xmlChildrenNode; node; node = node->next) {

    if (xmlStrcmp(node->name, (const xmlChar *) "in") == 0) {

      for (xmlNodePtr cur = node->xmlChildrenNode; cur; cur = cur->next) {
        if (xmlStrcmp(cur->name, synt_input_type::xml_name()) == 0) {
          in_.read_XML(cur, lingdef);
        }
      }

    } else if (xmlStrcmp(node->name, (const xmlChar *) "out") == 0) {

      for (xmlNodePtr cur = node->xmlChildrenNode; cur; cur = cur->next) {
        if (xmlStrcmp(cur->name, output_type::xml_name()) == 0) {
          out_.read_XML(cur);
        }
      }
    }
  }
}




/* Serialize as <info id="..."> holding the delayed outputs (through
 * write_outputs_XML) followed by every non-determinized transition.
 */
void syntagm_fst::state::info::write_XML(xmlwriter & writer, int id) const {

  const string idstr = lexical_cast<string>(id);

  writer.start_element("info");
  writer.write_attribute("id", idstr);

  write_outputs_XML(writer, delayeds);
  for_each(ndeter_trans.begin(), ndeter_trans.end(), XML_Writer(writer));

  writer.end_element();
}

/* Read an <info> element: attribute "id" into id (left unchanged when the
 * attribute is absent), then the element content via read_XML(node, lingdef).
 */
void syntagm_fst::state::info::read_XML(xmlNodePtr node, ling_def * lingdef, int & id) {

  // xmlGetProp returns NULL for a missing attribute; istringstream(NULL)
  // would be undefined behaviour
  char * text = (char *) xmlGetProp(node, (const xmlChar *) "id");
  if (text) {
    istringstream(text) >> id;
    xmlFree(text);
  }

  read_XML(node, lingdef);
}

/* Replace delayeds and ndeter_trans with the content of an <info> element.
 * An outputs_xml_name() child provides the delayed outputs; every
 * ndeter_transition::xml_name() child appends a non-determinized
 * transition.  Other children are ignored (non-determinized synt
 * transitions are not read).
 */
void syntagm_fst::state::info::read_XML(xmlNodePtr node, ling_def * lingdef) {

  delayeds.clear();
  ndeter_trans.clear();

  for (xmlNodePtr child = node->xmlChildrenNode; child; child = child->next) {

    if (xmlStrcmp(child->name, outputs_xml_name()) == 0) {
      read_outputs_XML(child, delayeds);
      continue;
    }

    if (xmlStrcmp(child->name, ndeter_transition::xml_name()) == 0) {
      ndeter_trans.push_back(ndeter_transition(child, lingdef));
    }
  }
}



/* Serialize a whole info map as an <info_map> element, writing each entry
 * through info::write_XML with the map key as its id.
 */
void syntagm_fst::state::write_infos_XML(xmlwriter & writer, const info_map & infos) {

  writer.start_element("info_map");

  info_map::const_iterator entry = infos.begin();
  const info_map::const_iterator stop = infos.end();
  while (entry != stop) {
    entry->second.write_XML(writer, entry->first);
    ++entry;
  }

  writer.end_element();
}

/* Replace infos with the content of an info-map element: every <info> child
 * is read into infos[id], id being the child's "id" attribute.  <info>
 * children whose id is missing or not an integer are reported on cerr and
 * skipped.
 */
void syntagm_fst::state::read_infos_XML(xmlNodePtr node, ling_def * lingdef, info_map & infos) {

  infos.clear();

  for (node = node->xmlChildrenNode; node; node = node->next) {

    if (xmlStrcmp(node->name, (const xmlChar *) "info") != 0) {
      continue;
    }

    // xmlGetProp returns NULL for a missing attribute; istringstream(NULL)
    // would be undefined behaviour
    char * text = (char *) xmlGetProp(node, (const xmlChar *) "id");
    if (! text) {
      cerr << "read_infos_XML: <info> without id attribute, skipped\n";
      continue;
    }

    istringstream idstream(text); // copies the text, so it can be freed now
    xmlFree(text);

    int id;
    if (! (idstream >> id)) {
      cerr << "read_infos_XML: <info> with invalid id attribute, skipped\n";
      continue;
    }

    info & inf = infos[id];
    inf.read_XML(node, lingdef);
  }
}

/* Serialize a set of outputs as an <outputs> element; each output writes its
 * own XML as a child.
 */
void syntagm_fst::state::write_outputs_XML(xmlwriter & writer, const outputs_type & out) {

  writer.start_element("outputs");

  outputs_type::const_iterator cur = out.begin();
  const outputs_type::const_iterator stop = out.end();
  for (; cur != stop; ++cur) {
    cur->write_XML(writer);
  }

  writer.end_element();
}

/* Replace out with one output per output_type::xml_name() child of node;
 * other children are ignored.
 */
void syntagm_fst::state::read_outputs_XML(xmlNodePtr node, outputs_type & out) {

  out.clear();

  for (xmlNodePtr child = node->xmlChildrenNode; child; child = child->next) {
    if (xmlStrcmp(child->name, output_type::xml_name()) != 0) {
      continue;
    }
    out.insert(output_type(child));
  }
}

/* Replace this state with the content of a <state> element.
 *
 * Recognized content:
 *  - attribute "final": the final flag.  It is absent for non-final states
 *    (state::write_XML only emits it when final), which means "not final".
 *  - <outputs>: the final outputs.
 *  - <info_map>: the per-position infos.  This is the element name
 *    write_infos_XML emits; the spelling "infos_map" is also accepted.
 *  - <transition> / <synt_transition>: outgoing transitions.  A transition
 *    whose loading throws xml_parse_error is reported on cerr and skipped.
 */
void syntagm_fst::state::read_XML(xmlNodePtr node, ling_def * lingdef) {

  final_outputs.clear();
  infos.clear();
  trans.clear();
  synt_trans.clear();

  // xmlGetProp returns NULL when the attribute is absent (every non-final
  // state); istringstream(NULL) would be undefined behaviour
  final = false;
  char * text = (char *) xmlGetProp(node, (const xmlChar *) "final");
  if (text) {
    istringstream(text) >> final;
    xmlFree(text);
  }

  for (node = node->xmlChildrenNode; node; node = node->next) {

    if (xmlStrcmp(node->name, (const xmlChar *) "outputs") == 0) {

      read_outputs_XML(node, final_outputs);

    } else if (xmlStrcmp(node->name, (const xmlChar *) "info_map") == 0
               || xmlStrcmp(node->name, (const xmlChar *) "infos_map") == 0) {

      read_infos_XML(node, lingdef, infos);

    } else if (xmlStrcmp(node->name, (const xmlChar *) "transition") == 0) {

      int size = trans.size();

      try {

        trans.resize(size + 1);
        trans[size].read_XML(node, lingdef);

      } catch (xml_parse_error & e) {

        cerr << "unable to load transition: " << e.what() << '\n';
        trans.resize(size); // drop the partially read transition
      }

    } else if (xmlStrcmp(node->name, (const xmlChar *) "synt_transition") == 0) {

      try {

        synt_trans.push_back(synt_transition(node, lingdef));

      } catch (xml_parse_error & e) {

        cerr << "unable to load transition: " << e.what() << '\n';
      }
    }
  }
}

/* Serialize as <state id="q"> (plus final="1" for final states) holding the
 * final outputs, the info map, every transition and every synt transition,
 * in that order.
 */
void syntagm_fst::state::write_XML(xmlwriter & writer, int q) const {

  writer.start_element("state");
  writer.write_attribute("id", lexical_cast<string>(q));
  if (final) { writer.write_attribute("final", "1"); }

  write_outputs_XML(writer, final_outputs);
  write_infos_XML(writer, infos);

  transitions::const_iterator tr = trans.begin();
  while (tr != trans.end()) {
    tr->write_XML(writer);
    ++tr;
  }

  synt_transitions::const_iterator str = synt_trans.begin();
  while (str != synt_trans.end()) {
    str->write_XML(writer);
    ++str;
  }

  writer.end_element();
}


/* Read a state-id map entry: attribute "q" into q, and every <position>
 * child into id (id is cleared first).  q is set to -1 when the "q"
 * attribute is absent, so callers can detect the incomplete entry.
 */
void syntagm_fst::read_idmap_XML(xmlNodePtr node, stateid & id, int & q) {

  // xmlGetProp returns NULL for a missing attribute; istringstream(NULL)
  // would be undefined behaviour
  char * text = (char *) xmlGetProp(node, (const xmlChar *) "q");
  if (text) {
    istringstream(text) >> q;
    xmlFree(text);
  } else {
    q = -1;
  }

  id.clear();

  for (node = node->xmlChildrenNode; node; node = node->next) {
    if (xmlStrcmp(node->name, (const xmlChar *) "position") == 0) {
      id.insert(position(node));
    }
  }
}


/* Serialize one id -> state entry as <state_id q="..."> wrapping each
 * position of id.
 */
void syntagm_fst::write_idmap_XML(xmlwriter & writer, const stateid & id, int q) const {

  writer.start_element("state_id");
  writer.write_attribute("q", lexical_cast<string>(q));

  stateid::const_iterator pos = id.begin();
  while (pos != id.end()) {
    pos->write_XML(writer);
    ++pos;
  }

  writer.end_element();
}

/* Serialize the whole transducer as <syntagm_fst name="...">: the pattern,
 * every state in index order, then one id-map entry per known state id.
 */
void syntagm_fst::write_XML(xmlwriter & writer) const {

  writer.start_element("syntagm_fst");
  writer.write_attribute("name", name);

  A.write_XML(writer);

  const int nbstates = states.size();
  for (int q = 0; q < nbstates; ++q) {
    states[q].write_XML(writer, q);
  }

  state_map::const_iterator entry = id2state.begin();
  for (; entry != id2state.end(); ++entry) {
    write_idmap_XML(writer, entry->first, entry->second);
  }

  writer.end_element();
}

/* Load the transducer from XML.
 *
 * A <syntagm_pattern> node is loaded as a fresh pattern: the transducer is
 * reduced to its initial state (id = {(start, output_type::one())}).
 * Any other node is read as a serialized transducer:
 *  - attribute "name" (the name stays empty when absent);
 *  - optional attribute "size": a capacity hint for the state vector;
 *  - children: the pattern, the <state> elements (indexed in document
 *    order) and the state-id map entries.  Entries whose state index is
 *    negative (read_idmap_XML may report -1 for a missing "q") are skipped.
 */
void syntagm_fst::read_XML(xmlNodePtr node, ling_def * ldef) {

  clear();

  lingdef = ldef;

  if (xmlStrcmp(node->name, (const xmlChar *) "syntagm_pattern") == 0) {

    A.read_XML(node, lingdef);

    name = A.get_name();

    // create initial state...
    stateid id;
    id.insert(position(A.start(), output_type::one()));
    add_state(id);
    // we're done
    return;
  }

  // xmlGetProp returns NULL for a missing attribute; assigning a NULL
  // char* to a std::string is undefined behaviour
  char * text = (char *) xmlGetProp(node, (const xmlChar *) "name");
  if (text) {
    name = text;
    xmlFree(text);
  }

  text = (char *) xmlGetProp(node, (const xmlChar *) "size");
  if (text) {
    int size = 0;
    istringstream(text) >> size;
    xmlFree(text);
    if (size > 0) { // a negative size would make reserve() throw
      states.reserve(size);
    }
  }

  int qno = 0;
  for (node = node->xmlChildrenNode; node; node = node->next) {

    if (xmlStrcmp(node->name, syntagm_pattern::xml_name()) == 0) {

      A.read_XML(node, lingdef);

    } else if (xmlStrcmp(node->name, (const xmlChar *) "state") == 0) {

      states.resize(qno + 1);
      states[qno].read_XML(node, lingdef);
      qno++;

    } else if (xmlStrcmp(node->name, idmap_xml_name()) == 0) {

      stateid id;
      int q = -1;
      read_idmap_XML(node, id, q);
      if (q >= 0) {
        id2state[id] = q;
      }
    }
  }
}

