{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "aug_data_path = \"/Users/minjoons/data/squad/dev-v1.0-aug.json\"\n",
    "aug_data = json.load(open(aug_data_path, 'r'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(['Denver', 'Broncos'], 'Denver Broncos')\n",
      "(['Denver', 'Broncos'], 'Denver Broncos')\n",
      "(['Denver', 'Broncos'], 'Denver Broncos ')\n",
      "(['Carolina', 'Panthers'], 'Carolina Panthers')\n"
     ]
    }
   ],
   "source": [
    "def compare_answers():\n",
    "    for article in aug_data['data']:\n",
    "        for para in article['paragraphs']:\n",
    "            deps = para['deps']\n",
    "            nodess = []\n",
    "            for dep in deps:\n",
    "                nodes, edges = dep\n",
    "                if dep is not None:\n",
    "                    nodess.append(nodes)\n",
    "                else:\n",
    "                    nodess.append([])\n",
    "            wordss = [[node[0] for node in nodes] for nodes in nodess]\n",
    "            for qa in para['qas']:\n",
    "                for answer in qa['answers']:\n",
    "                    text = answer['text']\n",
    "                    word_start = answer['answer_word_start']\n",
    "                    word_stop = answer['answer_word_stop']\n",
    "                    answer_words = wordss[word_start[0]][word_start[1]:word_stop[1]]\n",
    "                    yield answer_words, text\n",
    "\n",
    "ca = compare_answers()\n",
    "print(next(ca))\n",
    "print(next(ca))\n",
    "print(next(ca))\n",
    "print(next(ca))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "8\n"
     ]
    }
   ],
   "source": [
    "def counter():\n",
    "    count = 0\n",
    "    for article in aug_data['data']:\n",
    "        for para in article['paragraphs']:\n",
    "            deps = para['deps']\n",
    "            nodess = []\n",
    "            for dep in deps:\n",
    "                if dep is None:\n",
    "                    count += 1\n",
    "    print(count)\n",
    "counter()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n"
     ]
    }
   ],
   "source": [
    "def bad_node_counter():\n",
    "    count = 0\n",
    "    for article in aug_data['data']:\n",
    "        for para in article['paragraphs']:\n",
    "            sents = para['sents']\n",
    "            deps = para['deps']\n",
    "            nodess = []\n",
    "            for dep in deps:\n",
    "                if dep is not None:\n",
    "                    nodes, edges = dep\n",
    "                    for node in nodes:\n",
    "                        if len(node) != 5:\n",
    "                            count += 1\n",
    "    print(count)\n",
    "bad_node_counter()  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7\n"
     ]
    }
   ],
   "source": [
    "def noanswer_counter():\n",
    "    count = 0\n",
    "    for article in aug_data['data']:\n",
    "        for para in article['paragraphs']:\n",
    "            deps = para['deps']\n",
    "            nodess = []\n",
    "            for dep in deps:\n",
    "                if dep is not None:\n",
    "                    nodes, edges = dep\n",
    "                    nodess.append(nodes)\n",
    "                else:\n",
    "                    nodess.append([])\n",
    "            wordss = [[node[0] for node in nodes] for nodes in nodess]\n",
    "            for qa in para['qas']:\n",
    "                for answer in qa['answers']:\n",
    "                    text = answer['text']\n",
    "                    word_start = answer['answer_word_start']\n",
    "                    word_stop = answer['answer_word_stop']\n",
    "                    if word_start is None:\n",
    "                        count += 1\n",
    "    print(count)\n",
    "noanswer_counter()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10600\n"
     ]
    }
   ],
   "source": [
    "print(sum(len(para['qas']) for a in aug_data['data'] for para in a['paragraphs']))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10348\n"
     ]
    }
   ],
   "source": [
    "import nltk\n",
    "\n",
    "def _set_span(t, i):\n",
    "    if isinstance(t[0], str):\n",
    "        t.span = (i, i+len(t))\n",
    "    else:\n",
    "        first = True\n",
    "        for c in t:\n",
    "            cur_span = _set_span(c, i)\n",
    "            i = cur_span[1]\n",
    "            if first:\n",
    "                min_ = cur_span[0]\n",
    "                first = False\n",
    "        max_ = cur_span[1]\n",
    "        t.span = (min_, max_)\n",
    "    return t.span\n",
    "\n",
    "\n",
    "def set_span(t):\n",
    "    assert isinstance(t, nltk.tree.Tree)\n",
    "    try:\n",
    "        return _set_span(t, 0)\n",
    "    except:\n",
    "        print(t)\n",
    "        exit()\n",
    "\n",
    "def same_span_counter():\n",
    "    count = 0\n",
    "    for article in aug_data['data']:\n",
    "        for para in article['paragraphs']:\n",
    "            consts = para['consts']\n",
    "            for const in consts:\n",
    "                tree = nltk.tree.Tree.fromstring(const)\n",
    "                set_span(tree)\n",
    "                if len(list(tree.subtrees())) > len(set(t.span for t in tree.subtrees())):\n",
    "                    count += 1\n",
    "    print(count)\n",
    "same_span_counter()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}