# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import app
from absl import flags
import matplotlib
import os

matplotlib.use('TkAgg')
import matplotlib.pyplot as plt

plt.style.use('ggplot')

FLAGS = flags.FLAGS
flags.DEFINE_string('plot_file', '', 'Output file name.')

qa_lnmax = [500, 750] + range(1000, 12500, 500)

acc_lnmax = [43.3, 52.3, 59.8, 66.7, 68.8, 70.5, 71.6, 72.3, 72.6, 72.9, 73.4,
             73.4, 73.7, 73.9, 74.2, 74.4, 74.5, 74.7, 74.8, 75, 75.1, 75.1,
             75.4, 75.4, 75.4]

qa_gnmax = [456, 683, 908, 1353, 1818, 2260, 2702, 3153, 3602, 4055, 4511, 4964,
            5422, 5875, 6332, 6792, 7244, 7696, 8146, 8599, 9041, 9496, 9945,
            10390, 10842]

acc_gnmax = [39.6, 52.2, 59.6, 66.6, 69.6, 70.5, 71.8, 72, 72.7, 72.9, 73.3,
             73.4, 73.4, 73.8, 74, 74.2, 74.4, 74.5, 74.5, 74.7, 74.8, 75, 75.1,
             75.1, 75.4]

qa_gnmax_aggressive = [167, 258, 322, 485, 647, 800, 967, 1133, 1282, 1430,
                       1573, 1728, 1889, 2028, 2190, 2348, 2510, 2668, 2950,
                       3098, 3265, 3413, 3581, 3730]

acc_gnmax_aggressive = [17.8, 26.8, 39.3, 48, 55.7, 61, 62.8, 64.8, 65.4, 66.7,
                        66.2, 68.3, 68.3, 68.7, 69.1, 70, 70.2, 70.5, 70.9,
                        70.7, 71.3, 71.3, 71.3, 71.8]


def main(argv):
  del argv  # Unused.

  plt.close('all')
  fig, ax = plt.subplots()
  fig.set_figheight(4.7)
  fig.set_figwidth(5)
  ax.plot(qa_lnmax, acc_lnmax, color='r', ls='--', linewidth=5., marker='o',
          alpha=.5, label='LNMax')
  ax.plot(qa_gnmax, acc_gnmax, color='g', ls='-', linewidth=5., marker='o',
          alpha=.5, label='Confident-GNMax')
  # ax.plot(qa_gnmax_aggressive, acc_gnmax_aggressive, color='b', ls='-', marker='o', alpha=.5, label='Confident-GNMax (aggressive)')
  plt.xticks([0, 2000, 4000, 6000])
  plt.xlim([0, 6000])
  # ax.set_yscale('log')
  plt.ylim([65, 76])
  ax.tick_params(labelsize=14)
  plt.xlabel('Number of queries answered', fontsize=16)
  plt.ylabel('Student test accuracy (%)', fontsize=16)
  plt.legend(loc=2, prop={'size': 16})

  x = [400, 2116, 4600, 4680]
  y = [69.5, 68.5, 74, 72.5]
  annotations = [0.76, 2.89, 1.42, 5.76]
  color_annotations = ['g', 'r', 'g', 'r']
  for i, txt in enumerate(annotations):
    ax.annotate(r'${\varepsilon=}$' + str(txt), (x[i], y[i]), fontsize=16,
                color=color_annotations[i])

  plot_filename = os.path.expanduser(FLAGS.plot_file)
  plt.savefig(plot_filename, bbox_inches='tight')
  plt.show()

if __name__ == '__main__':
  app.run(main)