dawn-bench-models/tensorflow/SQuAD/squad/eda_aug_dev.ipynb

272 lines
6.9 KiB
Text
Raw Normal View History

2017-08-17 12:43:17 -06:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import json\n",
"\n",
"# NOTE: machine-specific absolute path; point it at your local copy of dev-v1.0-aug.json.\n",
"aug_data_path = \"/Users/minjoons/data/squad/dev-v1.0-aug.json\"\n",
"\n",
"# Use a context manager so the file handle is closed as soon as parsing finishes.\n",
"with open(aug_data_path, 'r') as aug_data_file:\n",
"    aug_data = json.load(aug_data_file)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(['Denver', 'Broncos'], 'Denver Broncos')\n",
"(['Denver', 'Broncos'], 'Denver Broncos')\n",
"(['Denver', 'Broncos'], 'Denver Broncos ')\n",
"(['Carolina', 'Panthers'], 'Carolina Panthers')\n"
]
}
],
"source": [
"def compare_answers():\n",
"    \"\"\"Yield ``(answer_words, text)`` for every answer in ``aug_data``.\n",
"\n",
"    ``answer_words`` is the slice of per-node first elements selected by the\n",
"    answer's ``answer_word_start`` / ``answer_word_stop`` index pairs, and\n",
"    ``text`` is the answer's ``text`` field, so the two can be compared by eye.\n",
"    Answers whose word indices are ``None`` are skipped.\n",
"    \"\"\"\n",
"    for article in aug_data['data']:\n",
"        for para in article['paragraphs']:\n",
"            nodess = []\n",
"            for dep in para['deps']:\n",
"                # Test for None *before* unpacking: unpacking a None entry\n",
"                # raises TypeError.\n",
"                if dep is not None:\n",
"                    nodes, _edges = dep\n",
"                    nodess.append(nodes)\n",
"                else:\n",
"                    # Keep an empty placeholder so list positions still line up\n",
"                    # with the indices stored in the answers.\n",
"                    nodess.append([])\n",
"            wordss = [[node[0] for node in nodes] for nodes in nodess]\n",
"            for qa in para['qas']:\n",
"                for answer in qa['answers']:\n",
"                    text = answer['text']\n",
"                    word_start = answer['answer_word_start']\n",
"                    word_stop = answer['answer_word_stop']\n",
"                    if word_start is None or word_stop is None:\n",
"                        # No word indices recorded for this answer; nothing to slice.\n",
"                        continue\n",
"                    # word_start[0] picks the word list; the [1] entries bound the slice.\n",
"                    answer_words = wordss[word_start[0]][word_start[1]:word_stop[1]]\n",
"                    yield answer_words, text\n",
"\n",
"ca = compare_answers()\n",
"print(next(ca))\n",
"print(next(ca))\n",
"print(next(ca))\n",
"print(next(ca))"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"8\n"
]
}
],
"source": [
"def counter():\n",
"    \"\"\"Print how many entries of the paragraphs' ``deps`` lists are ``None``.\"\"\"\n",
"    count = sum(\n",
"        1\n",
"        for article in aug_data['data']\n",
"        for para in article['paragraphs']\n",
"        for dep in para['deps']\n",
"        if dep is None\n",
"    )\n",
"    print(count)\n",
"counter()\n"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0\n"
]
}
],
"source": [
"def bad_node_counter():\n",
"    \"\"\"Print how many nodes in the ``deps`` entries do not have length 5.\n",
"\n",
"    ``None`` entries in a paragraph's ``deps`` list are skipped.\n",
"    \"\"\"\n",
"    count = 0\n",
"    for article in aug_data['data']:\n",
"        for para in article['paragraphs']:\n",
"            for dep in para['deps']:\n",
"                if dep is None:\n",
"                    continue\n",
"                # Each non-None entry unpacks into a (nodes, edges) pair;\n",
"                # only the nodes are inspected here.\n",
"                nodes, _edges = dep\n",
"                count += sum(1 for node in nodes if len(node) != 5)\n",
"    print(count)\n",
"bad_node_counter()"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"7\n"
]
}
],
"source": [
"def noanswer_counter():\n",
"    \"\"\"Print how many answers have ``answer_word_start`` set to ``None``.\"\"\"\n",
"    # Only the answers are needed for this count; the ``deps`` lists are not consulted.\n",
"    count = sum(\n",
"        1\n",
"        for article in aug_data['data']\n",
"        for para in article['paragraphs']\n",
"        for qa in para['qas']\n",
"        for answer in qa['answers']\n",
"        if answer['answer_word_start'] is None\n",
"    )\n",
"    print(count)\n",
"noanswer_counter()"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10600\n"
]
}
],
"source": [
"# Total number of ``qas`` entries summed over every paragraph of every article.\n",
"print(sum(\n",
"    len(paragraph['qas'])\n",
"    for article in aug_data['data']\n",
"    for paragraph in article['paragraphs']\n",
"))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10348\n"
]
}
],
"source": [
"import nltk\n",
"\n",
"def _set_span(t, i):\n",
"    \"\"\"Recursively attach a ``span`` attribute to ``t`` and its nested nodes.\n",
"\n",
"    ``i`` is the offset at which ``t`` starts. A node whose first child is a\n",
"    ``str`` is treated as holding only leaves and gets ``(i, i + len(t))``;\n",
"    any other node spans from ``i`` to the end of its last child. Only the\n",
"    first child is inspected to decide which case applies. A node with no\n",
"    children gets the empty span ``(i, i)``.\n",
"\n",
"    Returns the ``(start, stop)`` tuple stored on ``t``.\n",
"    \"\"\"\n",
"    if len(t) == 0:\n",
"        # Childless node: there is no t[0] to inspect, so give it an empty span\n",
"        # instead of failing with IndexError.\n",
"        t.span = (i, i)\n",
"    elif isinstance(t[0], str):\n",
"        t.span = (i, i + len(t))\n",
"    else:\n",
"        start = i\n",
"        for c in t:\n",
"            # Each child starts where the previous one stopped.\n",
"            i = _set_span(c, i)[1]\n",
"        t.span = (start, i)\n",
"    return t.span\n",
"\n",
"\n",
"def set_span(t):\n",
"    \"\"\"Annotate every node of the nltk tree ``t`` with a ``span`` attribute.\n",
"\n",
"    Spans are computed by ``_set_span`` starting at offset 0. Returns the\n",
"    span of the root. If annotation fails, the offending tree is printed and\n",
"    the original exception is re-raised.\n",
"    \"\"\"\n",
"    assert isinstance(t, nltk.tree.Tree)\n",
"    try:\n",
"        return _set_span(t, 0)\n",
"    except Exception:\n",
"        # Show the tree that failed, then let the error propagate. Calling\n",
"        # exit() here would shut the kernel down and hide the traceback.\n",
"        print(t)\n",
"        raise\n",
"\n",
"def same_span_counter():\n",
"    \"\"\"Print how many trees contain at least two subtrees with the same span.\n",
"\n",
"    Every string in each paragraph's ``consts`` list is parsed with\n",
"    ``nltk.tree.Tree.fromstring`` and annotated via ``set_span`` before its\n",
"    subtrees are compared.\n",
"    \"\"\"\n",
"    duplicated = 0\n",
"    for article in aug_data['data']:\n",
"        for para in article['paragraphs']:\n",
"            for const in para['consts']:\n",
"                tree = nltk.tree.Tree.fromstring(const)\n",
"                set_span(tree)\n",
"                subtrees = list(tree.subtrees())\n",
"                distinct_spans = {subtree.span for subtree in subtrees}\n",
"                # Fewer distinct spans than subtrees means some span repeats.\n",
"                if len(distinct_spans) < len(subtrees):\n",
"                    duplicated += 1\n",
"    print(duplicated)\n",
"same_span_counter()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.1"
}
},
"nbformat": 4,
"nbformat_minor": 0
}