import BaseAgent from "./baseAgent";
import Tryonumjs from "../tryonumjs";

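/**
 * Select the price (action or bandit arm) using softmax (Boltzmann) exploration:
 * each profitable price is drawn with a probability that grows exponentially
 * with its estimated profit relative to the average profit, scaled by a
 * temperature (lower temperature means greedier choices).
 */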
export default class SoftmaxGreedyAgent extends BaseAgent {
  /**
   * @param prices Candidate prices; passed straight to the BaseAgent
   *   constructor (which presumably provides `prices`, `nPrices`, `pricesIdx`
   *   and `productCost` on the instance - confirm against BaseAgent).
   * @param {number} temperature Softmax temperature, expected to be > 0.
   *   Smaller values concentrate the probability on the best price.
   * @param {boolean} initialExploration When true, the first `nPrices`
   *   selections walk through every price once, in order.
   * @param {boolean} adjacentOnly When true, only the current price and its
   *   immediate neighbours can be selected.
   */
  constructor(
    prices,
    temperature = 0.1,
    initialExploration = false,
    adjacentOnly = false
  ) {
    super(prices);
    this.initialExploration = initialExploration;
    this.adjacentOnly = adjacentOnly;
    this.temperature = temperature;
    this.reset();
  }

  /**
   * Forget everything learned so far: selection state, per-price mean sales,
   * overall mean sales and per-price observation counts.
   */
  reset() {
    this.currentPriceIdx = 0;
    this.exploredPrices = 0;
    this.meanPriceSales = Tryonumjs.zeros(this.nPrices);
    this.meanTotalSales = 0;
    this.counts = Tryonumjs.zeros(this.nPrices);
  }

  /**
   * @returns {string} Human-readable label describing this agent's settings.
   */
  getName() {
    return (
      `SoftmaxGreedy (` +
      `initExplore = ${this.initialExploration}, ` +
      `adjacentOnly = ${this.adjacentOnly}` +
      `)`
    );
  }

  /**
   * Compute one softmax logit per price.
   *
   * The logit of a profitable price is its advantage (mean profit at that
   * price minus the baseline profit) divided by `temperature * |baseline|`.
   * Prices whose margin is not strictly positive get `-Infinity`, i.e. a
   * selection probability of zero.
   *
   * @returns {number[]} Logits, indexed like `this.prices`.
   */
  _priceLogits() {
    const priceProfits = Tryonumjs.substract(this.prices, this.productCost);
    const meanProfit = Tryonumjs.sum(priceProfits) / priceProfits.length;
    const totalPriceProfits = Tryonumjs.multiply(
      priceProfits,
      this.meanPriceSales
    );
    const meanTotalProfit = this.meanTotalSales * meanProfit;
    const advantages = Tryonumjs.substract(totalPriceProfits, meanTotalProfit);

    // Normalise by the magnitude of the baseline: a negative baseline must not
    // flip the ordering and make the most profitable price the least likely.
    const scale = this.temperature * Math.abs(meanTotalProfit);
    // Before any sale is recorded (or with a zero temperature) the scale is 0;
    // dividing by it would produce NaN probabilities, so every profitable
    // price gets the same logit instead.
    const hasScale = scale !== 0 && Number.isFinite(scale);

    const logits = [];
    for (let i = 0; i < priceProfits.length; i++) {
      if (priceProfits[i] <= 0) {
        // Never pick a price that does not yield a positive margin.
        logits.push(-Infinity);
      } else {
        logits.push(hasScale ? advantages[i] / scale : 0);
      }
    }
    return logits;
  }

  /**
   * Turn logits into probabilities with a max-shifted softmax.
   *
   * Shifting by the largest logit keeps `Math.exp` from overflowing at small
   * temperatures. If every logit is `-Infinity` (no candidate is profitable)
   * the result is a uniform distribution, so a choice can still be made.
   *
   * @param {number[]} logits
   * @returns {number[]} Probabilities in the same order as `logits`.
   */
  _softmax(logits) {
    let maxLogit = -Infinity;
    for (let i = 0; i < logits.length; i++) {
      if (logits[i] > maxLogit) {
        maxLogit = logits[i];
      }
    }
    if (maxLogit === -Infinity) {
      return logits.map(() => 1 / logits.length);
    }
    const weights = logits.map((logit) => Math.exp(logit - maxLogit));
    const total = weights.reduce((acc, weight) => acc + weight, 0);
    return weights.map((weight) => weight / total);
  }

  /**
   * This is the action to take, and may be evaluated every certain number of
   * steps or every step; the algorithm should work either way.
   *
   * Updates `this.currentPriceIdx` as a side effect.
   *
   * @returns The selected entry of `this.prices`.
   */
  _selectPrice() {
    if (this.initialExploration && this.exploredPrices < this.nPrices) {
      // Initial sweep: try each price once, in index order.
      this.currentPriceIdx = this.exploredPrices;
      this.exploredPrices += 1;
    } else {
      // Softmax exploration
      const logits = this._priceLogits();

      if (this.adjacentOnly) {
        const eligiblePrices = [this.currentPriceIdx];
        if (this.currentPriceIdx > 0) {
          eligiblePrices.push(this.currentPriceIdx - 1);
        }
        if (this.currentPriceIdx < this.nPrices - 1) {
          eligiblePrices.push(this.currentPriceIdx + 1);
        }
        // Softmax restricted to the neighbourhood; computed from the logits
        // directly so that small probabilities are not lost to underflow.
        const eligibleProbs = this._softmax(
          eligiblePrices.map((i) => logits[i])
        );

        this.currentPriceIdx = Tryonumjs.randomChoice(
          eligiblePrices,
          eligibleProbs
        );
      } else {
        this.currentPriceIdx = Tryonumjs.randomChoice(
          this.pricesIdx,
          this._softmax(logits)
        );
      }
    }
    return this.prices[this.currentPriceIdx];
  }

  /**
   * Feedback for the agent.
   * This is intended to be called on every step (or accumulated otherwise),
   * regardless of how often selectPrice() is called.
   *
   * Intentionally a no-op: this agent learns from `_updateSales` only.
   */
  _reward(reward) {}

  /**
   * Record the sales observed at the currently selected price, updating the
   * running mean for that price and the running mean over all observations.
   *
   * @param {number} sales Sales observed for the current price.
   */
  _updateSales(sales) {
    this.counts[this.currentPriceIdx] += 1;
    const meanPriceSales = this.meanPriceSales[this.currentPriceIdx];

    // Incremental mean: mean += (x - mean) / n
    const n_a = this.counts[this.currentPriceIdx];
    this.meanPriceSales[this.currentPriceIdx] += (sales - meanPriceSales) / n_a;
    this.meanTotalSales +=
      (sales - this.meanTotalSales) / Tryonumjs.sum(this.counts);
  }
}
