dawn-bench-models/tensorflow/SQuAD/tree/test.ipynb

295 lines
6.3 KiB
Text
Raw Normal View History

2017-08-17 12:43:17 -06:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import nltk\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(S (PRP I) (VP (VBP am) (NNP Sam)) (. .))\n",
"(PRP I)\n",
"(VP (VBP am) (NNP Sam))\n",
"(VBP am)\n",
"(NNP Sam)\n",
"(. .)\n",
"(S (PRP I) (VP (VBP am) (NNP Sam)) (. .))\n"
]
}
],
"source": [
"string = \"(ROOT(S(NP (PRP I))(VP (VBP am)(NP (NNP Sam)))(. .)))\"\n",
"tree = nltk.tree.Tree.fromstring(string)\n",
"\n",
"def load_compressed_tree(s):\n",
"    \"\"\"Parse a bracketed tree string and collapse its unary chains.\n",
"\n",
"    A node whose only child is itself a tree is replaced by that child, so\n",
"    only the lowest label of a unary chain is kept, e.g. ``(NP (PRP I))``\n",
"    becomes ``(PRP I)``.\n",
"\n",
"    Args:\n",
"        s: bracketed tree string, parsed with ``nltk.tree.Tree.fromstring``.\n",
"\n",
"    Returns:\n",
"        The parsed ``nltk.tree.Tree`` with unary chains collapsed.\n",
"    \"\"\"\n",
"\n",
"    def compress_tree(tree):\n",
"        # Leaves are plain strings: hand them back untouched. Without this\n",
"        # guard a multi-character leaf under a multi-child node, e.g.\n",
"        # \"(NP the cat)\", reaches the loop below and fails on item\n",
"        # assignment into a str.\n",
"        if not isinstance(tree, nltk.tree.Tree):\n",
"            return tree\n",
"        if len(tree) == 1 and isinstance(tree[0], nltk.tree.Tree):\n",
"            # Unary chain: drop this node and keep compressing its only child.\n",
"            return compress_tree(tree[0])\n",
"        if len(tree) > 1:\n",
"            # Rewrite the children in place; the node itself is kept.\n",
"            for i, child in enumerate(tree):\n",
"                tree[i] = compress_tree(child)\n",
"        return tree\n",
"\n",
"    return compress_tree(nltk.tree.Tree.fromstring(s))\n",
"tree = load_compressed_tree(string)\n",
"# Show every subtree of the compressed tree, then the whole tree again.\n",
"for node in tree.subtrees():\n",
"    print(node)\n",
"\n",
"print(tree)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(ROOT I am Sam .)\n"
]
}
],
"source": [
"print(tree.flatten())"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['ROOT', 'S', 'NP', 'PRP', 'VP', 'VBP', 'NP', 'NNP', '.']\n"
]
}
],
"source": [
"print([subtree.label() for subtree in tree.subtrees()])"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"import json\n",
"# Use a context manager so the file handle is closed once parsing is done.\n",
"with open(\"data/squad/shared_dev.json\", 'r') as shared_dev_file:\n",
"    d = json.load(shared_dev_file)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"73"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(d['pos_counter'])"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"collapsed": false
},
"outputs": [
{
"data": {
"text/plain": [
"{'#': 6,\n",
" '$': 80,\n",
" \"''\": 1291,\n",
" ',': 14136,\n",
" '-LRB-': 1926,\n",
" '-RRB-': 1925,\n",
" '.': 9505,\n",
" ':': 1455,\n",
" 'ADJP': 3426,\n",
" 'ADVP': 4936,\n",
" 'CC': 9300,\n",
" 'CD': 6216,\n",
" 'CONJP': 191,\n",
" 'DT': 26286,\n",
" 'EX': 288,\n",
" 'FRAG': 107,\n",
" 'FW': 96,\n",
" 'IN': 32564,\n",
" 'INTJ': 12,\n",
" 'JJ': 21452,\n",
" 'JJR': 563,\n",
" 'JJS': 569,\n",
" 'LS': 7,\n",
" 'LST': 1,\n",
" 'MD': 1051,\n",
" 'NAC': 19,\n",
" 'NN': 34750,\n",
" 'NNP': 28392,\n",
" 'NNPS': 1400,\n",
" 'NNS': 16716,\n",
" 'NP': 91636,\n",
" 'NP-TMP': 236,\n",
" 'NX': 108,\n",
" 'PDT': 89,\n",
" 'POS': 1451,\n",
" 'PP': 33278,\n",
" 'PRN': 2085,\n",
" 'PRP': 2320,\n",
" 'PRP$': 1959,\n",
" 'PRT': 450,\n",
" 'QP': 838,\n",
" 'RB': 7611,\n",
" 'RBR': 301,\n",
" 'RBS': 252,\n",
" 'ROOT': 9587,\n",
" 'RP': 454,\n",
" 'RRC': 19,\n",
" 'S': 21557,\n",
" 'SBAR': 5009,\n",
" 'SBARQ': 6,\n",
" 'SINV': 135,\n",
" 'SQ': 5,\n",
" 'SYM': 17,\n",
" 'TO': 5167,\n",
" 'UCP': 143,\n",
" 'UH': 15,\n",
" 'VB': 4197,\n",
" 'VBD': 8377,\n",
" 'VBG': 3570,\n",
" 'VBN': 7218,\n",
" 'VBP': 2897,\n",
" 'VBZ': 4146,\n",
" 'VP': 33696,\n",
" 'WDT': 1368,\n",
" 'WHADJP': 5,\n",
" 'WHADVP': 439,\n",
" 'WHNP': 1927,\n",
" 'WHPP': 153,\n",
" 'WP': 482,\n",
" 'WP$': 50,\n",
" 'WRB': 442,\n",
" 'X': 23,\n",
" '``': 1269}"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"d['pos_counter']"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[False False False False]\n",
" [False True False False]\n",
" [False False False False]]\n",
"[[0 2 2 0]\n",
" [2 2 0 2]\n",
" [2 0 0 0]]\n"
]
}
],
"source": [
"from my.nltk_utils import tree2matrix, load_compressed_tree, find_max_f1_subtree, set_span\n",
"string = \"(ROOT(S(NP (PRP I))(VP (VBP am)(NP (NNP Sam)))(. .)))\"\n",
"tree = load_compressed_tree(string)\n",
"span = (1, 3)\n",
"set_span(tree)\n",
"subtree = find_max_f1_subtree(tree, span)\n",
"\n",
"\n",
"def is_target(node):\n",
"    \"\"\"Compare `node` with the subtree returned by find_max_f1_subtree.\"\"\"\n",
"    return node == subtree\n",
"\n",
"\n",
"def node_kind(node):\n",
"    \"\"\"Map string nodes to 1 and every other node to 2.\"\"\"\n",
"    return 1 if isinstance(node, str) else 2\n",
"\n",
"\n",
"# tree2matrix returns two values; only the first is displayed here. The\n",
"# second goes to a throwaway name so the notebook-level `d` (the dict read\n",
"# from shared_dev.json in an earlier cell) is not overwritten.\n",
"subtree_mask, _subtree_mask_rest = tree2matrix(tree, is_target, dtype='bool')\n",
"node_kinds, _node_kinds_rest = tree2matrix(tree, node_kind, dtype='int32')\n",
"print(subtree_mask)\n",
"print(node_kinds)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.1"
}
},
"nbformat": 4,
"nbformat_minor": 0
}