/* quiesce.cc
 */
#include "quiesce.h"
#include "osl/search/sortCaptureMoves.h"
#include "osl/search/shouldPromoteCut.h"
#include "osl/search/quiescenceGenerator.h"
#include "osl/move_order/captureSort.h"
#include "osl/checkmate/immediateCheckmate.h"
#include "osl/eval/evalTraits.h"
#include "osl/eval/pieceEval.h"
#include "osl/effect_util/effectUtil.h"
#include "osl/move_generator/capture_.h"
#include "osl/move_generator/promote_.h"
#include "osl/move_generator/escape_.h"
#include "osl/move_generator/capture_.h"
#include "osl/move_generator/allMoves.h"
#include "osl/move_action/store.h"
#include "osl/move_classifier/check_.h"
#include "osl/rating/ratedMoveVector.h"
#include <boost/foreach.hpp>
#include <iostream>

// Sentinel value (the largest size_t) meaning "no node is singled out for
// debug tracing".  A named cast is used instead of a C-style cast.
const size_t NoDebug = static_cast<size_t>(-1);

// Tracing of Quiesce::search is compiled in only when DEBUG_QUIESCE is
// defined.  The directive below appears to be deliberately misspelled so
// that it has no effect; change it to "#define DEBUG_QUIESCE" to enable it.
// #difeni DEBUG_QUIESCE
#ifdef DEBUG_QUIESCE
// Node number (value of Quiesce::time) whose search is traced in detail;
// with NoDebug the detailed trace is effectively disabled.
const size_t debug_time = NoDebug;
// const size_t debug_time = NoDebug-1;
// const size_t debug_time = 240;
#endif

// Enables the null-window (principal variation) search used at the root of
// Quiesce::search.
#define LEARN_PVS

/**
 * Prints a principal variation: every move followed by a single space,
 * then a terminating "\n".
 */
std::ostream& 
gpsshogi::operator<<(std::ostream& os, 
                     const gpsshogi::PVVector& pv)
{
  for (size_t i=0; i<pv.size(); ++i) {
    const Move move = pv[i];
    os << move << ' ';
  }
  os << "\n";
  return os;
}

// Default construction: the member container starts in its
// default-constructed state and no further setup is performed.
gpsshogi::
Table::Table()
{
  // intentionally empty
}

// Members are destroyed automatically; no explicit cleanup is done here.
gpsshogi::
Table::~Table()
{
  // intentionally empty
}

/**
 * Returns a pointer to the record stored under key, obtained via table[key]
 * (for a map-like container this presumably inserts a default record when
 * key is absent -- confirm against the header).
 */
gpsshogi::Record *gpsshogi::
Table::allocate(const HashKey& key)
{
  Record& entry = table[key];
  return &entry;
}

void gpsshogi::
Table::clear()
{
  table.clear();
}

/* ------------------------------------------------------------------------- */

/**
 * @param e  evaluation function; quiesce() builds the incremental
 *           evaluation stack from it (eval->newStack)
 * @param a  number of plies, counted from the root, in which all moves
 *           are generated
 * @param q  number of further plies searched with tactical moves only
 */
gpsshogi::
Quiesce::Quiesce(Eval *e, int a, int q) 
  : eval(e),
    all_moves_depth(a),
    quiesce_depth(q),
    root_depth_left(0),
    time(0),
    pv(64),            // one PV buffer per ply (indexed by history size)
    node_count(0)
{
}

// No explicit cleanup: members are destroyed automatically and the Eval
// pointer given to the constructor is not deleted here.
gpsshogi::
Quiesce::~Quiesce()
{
  // intentionally empty
}

// Empties the transposition table and resets the node counter.
// The node-time counter (time) is not reset here.
void gpsshogi::
Quiesce::clear()
{
  table.clear();
  node_count = 0;
}
    
/**
 * Convenience overload: searches with the widest window,
 * [-infty(BLACK), infty(BLACK)].
 */
bool gpsshogi::
Quiesce::quiesce(NumEffectState& state,
                 int& value, PVVector& pv)
{
  const int widest_low = -infty(BLACK);
  const int widest_high = infty(BLACK);
  return quiesce(state, value, pv, widest_low, widest_high);
}

/**
 * Runs the quiescence search from state.
 *
 * @param state  position to search; becomes the root of history_state
 * @param value  [out] value returned by search()
 * @param pv     [in] moves already played, used as the search history
 *               (its last move drives recapture/escape generation);
 *               [out] those moves followed by the principal variation found
 * @param alpha  one end of the search window
 * @param beta   other end; the two may be given in either order for the
 *               side to move, but must differ
 * @return always true
 */
bool gpsshogi::
Quiesce::quiesce(NumEffectState& state,
                 int& value, PVVector& pv, int alpha, int beta)
{
  assert(alpha != beta);
  history = pv;

  // normalize: alpha must not be better than beta for the side to move
  if (eval::betterThan(state.getTurn(), alpha, beta))
    std::swap(alpha, beta);

  // incremental search state: hash key, move history, evaluation stack
  key = HashKey(state);
  history_state.setRoot(state);
  eval_value.reset(eval->newStack(state));

  root_depth_left = all_moves_depth + quiesce_depth;
  value = search(alpha, beta, root_depth_left);

  // search() leaves the PV of the root ply in this->pv[history.size()]
  const size_t root_ply = history.size();
  assert(root_ply < this->pv.size());
  pv = this->pv[root_ply];
  return true;
}

void gpsshogi::
Quiesce::selectSeePlus(const NumEffectState& state, const MoveVector& src, MoveVector& dst,
		       int threshold) const
{
  RatedMoveVector moves;
  for (size_t i=0; i<src.size(); ++i) {
    const int see = osl::eval::PieceEval::computeDiffAfterMoveForRP(state, src[i]);
    if (see < threshold)
      continue;
    moves.push_back(RatedMove(src[i], see));
  }
  moves.sort();
  for (size_t i=0; i<moves.size(); ++i)
    dst.push_back(moves[i].move());
}

/**
 * Generates captures of every opponent piece on the board plus the moves
 * from GeneratePromote<true>, then appends to out those accepted by
 * selectSeePlus with its default threshold.
 */
void gpsshogi::
Quiesce::generateTacticalMoves(MoveVector& out) const
{
  const Player turn = state()->getTurn();
  const Player opponent = alt(turn);
  MoveVector candidates;
  {
    move_action::Store store(candidates);
    // captures, one target square at a time
    for (int n=0; n<Piece::SIZE; ++n) {
      const Piece target = state()->getPieceOf(n);
      if (target.isOnBoardByOwner(opponent))
        move_generator::GenerateCapture::generate(turn, *state(), target.position(),
                                                  store);
    }
    // promotions
    move_generator::GeneratePromote<true>::generate(turn, *state(), store);
  }
  selectSeePlus(*state(), candidates, out);
}

/** Generates every move for the side to move, reordered by SortCaptureMoves::sortByTakeBack. */
void gpsshogi::
Quiesce::generateAllMoves(MoveVector& moves) const
{
  const Player turn = state()->getTurn();
  GenerateAllMoves::generate(turn, *state(), moves);
  search::SortCaptureMoves::sortByTakeBack(*state(), moves);
}
/**
 * Generates every move for the side to move and appends to moves those
 * accepted by selectSeePlus with the given threshold.
 */
void gpsshogi::
Quiesce::generateAllMovesSeePlus(MoveVector& moves, int threshold) const
{
  const Player turn = state()->getTurn();
  MoveVector all;
  GenerateAllMoves::generate(turn, *state(), all);
  selectSeePlus(*state(), all, moves, threshold);
}

/**
 * Generates captures onto square to and appends those accepted by
 * selectSeePlus (default threshold) to out.  Nothing is generated when
 * to is a piece-stand position.
 */
void gpsshogi::
Quiesce::generateTakeBack(MoveVector& out, Position to) const
{
  if (! to.isPieceStand()) {
    MoveVector captures;
    {
      move_action::Store store(captures);
      move_generator::GenerateCapture::generate(state()->getTurn(), *state(), to, store);
    }
    selectSeePlus(*state(), captures, out);
  }
}

/**
 * Appends to moves at most the first two moves produced by
 * QuiescenceGenerator<turn>::escapeFromLastMove (with PieceEval) for
 * last_move.
 */
void gpsshogi::
Quiesce::generateEscapeFromLastMove(MoveVector& moves, Move last_move) const
{
  typedef osl::PieceEval eval_t;
  MoveVector escapes;
  if (state()->getTurn() == WHITE)
    search::QuiescenceGenerator<WHITE>::escapeFromLastMove<eval_t>(*state(), last_move, escapes);
  else
    search::QuiescenceGenerator<BLACK>::escapeFromLastMove<eval_t>(*state(), last_move, escapes);

  const size_t max_escapes = 2;
  const size_t count = std::min(escapes.size(), max_escapes);
  for (size_t i=0; i<count; ++i)
    moves.push_back(escapes[i]);
}

/**
 * Alpha-beta search, with quiescence, of the position held in history_state.
 *
 * The window and all values are from the viewpoint of the side to move.
 * alpha must not be better than beta for that side (eval::betterThan).
 * The returned value is the best one found and may lie outside
 * [alpha, beta].  The principal variation of this ply is left in
 * pv[history.size()], prefixed by the moves in history.
 *
 * Move generation depends on depth_left:
 *  - while root_depth_left - all_moves_depth < depth_left, every move;
 *  - then, while depth_left > 0, tactical moves (captures and promotions),
 *    plus escapes from the last move at the first such ply;
 *  - afterwards, only recaptures onto the last move's destination.
 * When in check, king escapes are generated instead.
 *
 * Special values returned:
 *  - infty(turn) when the opponent's king is in check (the previous move
 *    was illegal);
 *  - infty(turn) when, not in check, a mate in one exists;
 *  - infty(turn) when a non-PV node was already recorded at a greater depth;
 *  - -infty(turn) when in check and no escape was generated.
 *
 * @param alpha      lower end of the window for the side to move
 * @param beta       upper end of the window for the side to move
 * @param depth_left remaining depth
 * @return value of the position for the side to move
 */
int gpsshogi::
Quiesce::search(int alpha, int beta, int depth_left)
{
  ++node_count;
  // NOTE(review): the null-window re-searches below pass (alpha+delta, alpha),
  // which differ unless eval::delta() is 0, so they are also treated as PV
  // nodes here (skipping the "! in_pv" table cut-offs). Confirm this is
  // intended.
  const bool in_pv = alpha != beta;
  const size_t my_time = ++time;
#ifdef DEBUG_QUIESCE
  if (my_time == debug_time) {
    Record *record = table.allocate(key);
    std::cerr << "debug_time node " << my_time << (in_pv ? " pv" : "") << "\n"
              << *state() << "[" << alpha << " " << beta << "] " << depth_left << "\n";
    std::cerr << "record [" << record->lower_bound << " " << record->lower_depth 
              << "  " << record->upper_bound << " " << record->upper_depth << "] at "
              << record->update_time << " " << record->best_move << "\n";
  }
  if (debug_time != NoDebug && time >= debug_time+1)
    std::cerr << "\nnew node " << (in_pv ? "pv " : "") << depth_left << " t " 
              << my_time << " [" << alpha << " " << beta << "] " << history;
#endif
  // true within the first all_moves_depth plies from the root
  const bool in_full_search = root_depth_left - all_moves_depth < depth_left;

  const Player turn = state()->getTurn();
  assert(! eval::betterThan(turn, alpha, beta));
  // the opponent's king is in check on our turn: the previous move was illegal
  if (state()->inCheck(alt(turn)))
    return infty(turn);

  const int initial_alpha = alpha;
  int best_value = -infty(turn);
  const bool is_king_in_check = state()->inCheck();
  Move best_move;
  assert(history.size() < this->pv.size());
  PVVector& cur_pv = pv[history.size()];
  cur_pv = history;
  // TODO: tsumero?
  const int stand_pat = eval_value->value();
  if (! is_king_in_check && ! in_full_search) {
    // assert(eval_value->value() == eval->eval(*state()));
#ifdef DEBUG_QUIESCE
    if (my_time == debug_time)
      std::cerr << "stand_pat " << stand_pat << "\n";
#endif
    // stand pat already reaches beta, or history has no room for another move
    if (eval::notLessThan(turn, stand_pat, beta)
        || history.size()+1 >= history.capacity()) {
      return stand_pat;
    }
    if (eval::betterThan(turn, stand_pat, best_value)) {
      best_value = stand_pat;
      best_move = Move::PASS(turn);
      alpha = eval::max(turn, alpha, best_value);
    }
  }

  // table
  Record *record = table.allocate(key);
  if (! in_pv) {
    // already searched deeper: return the side to move's best value, which
    // presumably makes the parent ignore this transposition
    int previous_visit = std::max(record->lower_depth, record->upper_depth);
    if (previous_visit > depth_left)
      return infty(turn);
  }
  const int checkmate_special_depth = 100;

  // lower bound from the table
  if (! in_pv && record->lower_depth >= depth_left) {
    if (eval::betterThan(turn, record->lower_bound, best_value)) {
      best_move = record->best_move;
#ifdef DEBUG_QUIESCE
      if (debug_time != NoDebug)
        std::cerr << "lower_bound hit " << record->lower_bound << " [" << alpha << " " << beta 
                  << "] at " << my_time << " " << record->update_time << "\n";
#endif
      if (eval::betterThan(turn, record->lower_bound, alpha)) {
        if (eval::notLessThan(turn, record->lower_bound, beta)) {
          return record->lower_bound;
        }
        alpha = record->lower_bound;
      }
      best_value = record->lower_bound;
    }
  }

  // mate in one for the side to move
  Move checkmate_move;
  if (! is_king_in_check
      && ImmediateCheckmate::hasCheckmateMove(turn, *state(), checkmate_move)) {
    record->best_move = checkmate_move;
    record->setValue(my_time, checkmate_special_depth, infty(turn));
    cur_pv.push_back(checkmate_move);
#ifdef DEBUG_QUIESCE
    if (my_time == debug_time)
      std::cerr << "checkmate at " << my_time << " " << record->lower_bound << " " << record->lower_depth
                << " " << record->best_move << "\npv " << cur_pv;
#endif
    return infty(turn);
  }

  // upper bound from the table
  if ((! in_pv) && (record->upper_depth >= depth_left)) {
    if (eval::betterThan(turn, beta, record->upper_bound)) {
#ifdef DEBUG_QUIESCE
      if (debug_time != NoDebug)
        std::cerr << "upper_bound hit " << record->upper_bound << " [" << alpha << " " << beta 
                  << "] at " << my_time << " " << record->update_time << " in_pv " << in_pv << "\n";
#endif
      if (eval::notLessThan(turn, alpha, record->upper_bound)) {
        return record->upper_bound;
      }
      beta = record->upper_bound;
    }
  }

  // search specials: the table's best move first (only captures beyond the
  // full-search plies)
  FixedCapacityVector<Move,2> first_moves;
  if (record->best_move.isNormal())
    if (in_full_search 
        || (depth_left > 0 && record->best_move.capturePtype() != PTYPE_EMPTY))
      first_moves.push_back(record->best_move);
#ifdef LEARN_DEEP_SEARCH
  if (! state()->inCheck() && in_full_search && eval::betterThan(turn, stand_pat, alpha))
    first_moves.push_back(Move::PASS(turn));
#endif
  for (size_t i=0; i<first_moves.size(); ++i) {
    const Move m = first_moves[i];
    const int reduction = m.isPass() ? 2 : 1;
    assert(eval::notLessThan(turn, alpha, best_value));
    assert(eval::notLessThan(turn, beta, alpha));

    int value;
    {
      const HashKey old_hash = key;
      DoUndoMoveLock lock(history_state, m);
      history.push_back(m);

      key = old_hash.newHashWithMove(m);
      eval_value->push(*state(), m);
#ifdef LEARN_PVS
      if (in_pv && i > 0 && root_depth_left == depth_left) 
      {
        // null window first; re-search with the full window if it beats alpha
        value = search(alpha+eval::delta(turn), alpha, depth_left-reduction);
        if (eval::betterThan(turn, value, alpha)) {
          value = search(beta, alpha, depth_left-reduction);
        }
      }
      else 
#endif
      {
        value = search(beta, alpha, depth_left-reduction);
      }
      key = old_hash;
      eval_value->pop();
      history.pop_back();
    }
    if (eval::betterThan(turn, value, best_value)) {
      assert(history.size()+1 < this->pv.size());
      // adopt the child's principal variation
      cur_pv = pv[history.size()+1];
      if (eval::betterThan(turn, value, alpha)) {
        if (eval::notLessThan(turn, value, beta)) {
          record->setLowerBound(my_time, depth_left, value, in_pv);
          record->best_move = m;
          return value;
        }
        alpha = value;
      }
      best_value = value;
      best_move = m;
    }
  }

  // move generation
  MoveVector moves;
  if (is_king_in_check) {
    GenerateEscapeKing::generateCheap(*state(), moves);
    if (moves.empty()) {
      record->setValue(my_time, checkmate_special_depth, -infty(turn));
      return -infty(turn);
    }
    move_order::CaptureSort::sort(moves.begin(), moves.end());
  }
  else {
    if (in_full_search) {
#ifdef LEARN_DEEP_SEARCH
      generateAllMovesSeePlus(moves, -200*(all_moves_depth - (root_depth_left - depth_left)));
#else
      generateAllMoves(moves);
#endif
    }
    else if (depth_left > 0) {
      generateTacticalMoves(moves);
      if (root_depth_left == all_moves_depth + depth_left
          && ! history.empty())
        generateEscapeFromLastMove(moves, history.back());
    }
    else if (! history.empty()) {
      // Only recaptures onto the destination of the last move.  history can
      // be empty when the root itself is searched at this depth and the
      // caller supplied no moves; indexing it would then read out of
      // bounds, so no move is generated and the stand-pat value is kept.
      generateTakeBack(moves, history.back().to());
    }
  }
#ifdef DEBUG_QUIESCE
  if (my_time == debug_time-1)
    std::cerr << "time " << my_time << " best_move " << best_move << " " << record->best_move << "\n";
#endif  
  // search all
  // first_moves[0] is compared against below; when empty a PASS placeholder
  // is used, which presumably never equals a generated move
  if (first_moves.empty())
    first_moves.push_back(Move::PASS(turn));
  // a single evasion from check is not counted against the depth
  const int reduction = (is_king_in_check && moves.size() == 1) ? 0 : 1;
  for (MoveVector::const_iterator p=moves.begin(); p!=moves.end(); ++p) {
    if (first_moves[0] == *p)
      continue;                 // already searched above
    if (osl::search::ShouldPromoteCut::canIgnoreAndNotDrop(*p))
      continue;
#ifdef DEBUG_QUIESCE
    if (my_time == debug_time)
      std::cerr << "try " << *p << " " << cur_pv;
#endif
    assert(eval::notLessThan(turn, alpha, best_value));
    assert(eval::notLessThan(turn, beta, alpha));

    int value;
    {
      const HashKey old_hash = key;
      DoUndoMoveLock lock(history_state, *p);
      history.push_back(*p);

      key = old_hash.newHashWithMove(*p);
      eval_value->push(*state(), *p);
#ifdef LEARN_PVS
      if (in_pv && root_depth_left == depth_left) 
      {
        // null window first; re-search with the full window if it beats alpha
        value = search(alpha+eval::delta(turn), alpha, depth_left-reduction);
        if (eval::betterThan(turn, value, alpha)) {
          value = search(beta, alpha, depth_left-reduction);
        }
      }
      else 
#endif
      {
        value = search(beta, alpha, depth_left-reduction);
      }
#ifdef DEBUG_QUIESCE
      if (debug_time != NoDebug && time >= debug_time+1)
        std::cerr << "depth " << depth_left << " got " << value << " " << *p << " " << alpha << " " << beta << " at " 
                  << my_time << " " << time << "\n";
#endif
      key = old_hash;
      eval_value->pop();
      history.pop_back();
    }
    if (eval::betterThan(turn, value, best_value)) {
      assert(history.size()+1 < this->pv.size());
      // adopt the child's principal variation
      cur_pv = pv[history.size()+1];
      if (eval::betterThan(turn, value, alpha)) {
        if (eval::notLessThan(turn, value, beta)) {
          record->setLowerBound(my_time, depth_left, value, in_pv);
          record->best_move = *p;
#ifdef DEBUG_QUIESCE
          if (my_time == debug_time)
            std::cerr << "debug_time beta cut " << value << " " << record->update_time
                      << " " << record->lower_bound << " " << record->lower_depth << "\n";
#endif
          return value;
        }
        alpha = value;
      }
      best_value = value;
      best_move = *p;
#ifdef DEBUG_QUIESCE
      if (debug_time != NoDebug && my_time == debug_time)
        std::cerr << "time " << my_time << " best move updated " << best_move << "\n";
#endif
    }
  }
  // store the result: always an upper bound, plus a lower bound if alpha was raised
  record->setUpperBound(my_time, depth_left, best_value, in_pv);
  if (eval::betterThan(turn, best_value, initial_alpha)) {
    assert(in_pv || eval::notLessThan(turn, best_value, record->lower_bound) 
           || record->lower_depth < depth_left);
    record->setLowerBound(my_time, depth_left, best_value, in_pv);
  }
  if (in_pv
      || record->best_move.isInvalid()) {
    record->best_move = best_move;
    record->update_time = my_time;
  }
  // report (also without DEBUG_QUIESCE) a PV node whose stored bounds miss its value
  if (in_pv && ! (record->lower_bound == best_value || record->upper_bound == best_value)) {
    std::cerr << "time " << my_time << "\n" 
              << *state() << history << record->lower_bound << " " << record->upper_bound
              << " " << best_value << " " << initial_alpha << " " << alpha << " " << beta << "\n";
  }
#ifdef DEBUG_QUIESCE
  if (my_time == debug_time)
    std::cerr << "time " << my_time << " " << record->best_move << " " << best_move
              << " " << best_value
              << " [" << record->lower_bound << " " << record->upper_bound << "]"
              << record->lower_depth << " " << record->upper_depth << "\npv " << cur_pv;
#endif
  assert(! in_pv || record->lower_bound == best_value || record->upper_bound == best_value);
  return best_value;
}

/* ------------------------------------------------------------------------- */
// ;;; Local Variables:
// ;;; mode:c++
// ;;; c-basic-offset:2
// ;;; End:
