📝 dinglehopper: Update Levenshtein notebook

pull/66/head
Gerber, Mike 3 years ago
parent 3ee688001a
commit 06ea38449c

@ -18,62 +18,20 @@
"# Levenshtein edit distance" "# Levenshtein edit distance"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"dinglehopper used to have its own (very inefficient) Levenshtein edit distance implementation, but now uses RapidFuzz."
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 2,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [],
{
"name": "stdout",
"output_type": "stream",
"text": [
"def levenshtein_matrix(seq1, seq2):\n",
" \"\"\"Compute the matrix commonly computed to produce the Levenshtein distance.\n",
"\n",
" This is also known as the Wagner-Fischer algorithm. The matrix element at the bottom right contains the desired\n",
" edit distance.\n",
"\n",
" This algorithm is implemented here because we need an implementation that can work with sequences other than\n",
" strings, e.g. lists of grapheme clusters or lists of word strings.\n",
" \"\"\"\n",
" m = len(seq1)\n",
" n = len(seq2)\n",
"\n",
" def from_to(start, stop):\n",
" return range(start, stop + 1, 1)\n",
"\n",
" D = np.zeros((m + 1, n + 1), np.int)\n",
" D[0, 0] = 0\n",
" for i in from_to(1, m):\n",
" D[i, 0] = i\n",
" for j in from_to(1, n):\n",
" D[0, j] = j\n",
" for i in from_to(1, m):\n",
" for j in from_to(1, n):\n",
" D[i, j] = min(\n",
" D[i - 1, j - 1] + 1 * (seq1[i - 1] != seq2[j - 1]), # Same or Substitution\n",
" D[i, j - 1] + 1, # Insertion\n",
" D[i - 1, j] + 1 # Deletion\n",
" )\n",
"\n",
" return D\n",
"\n",
"def levenshtein(seq1, seq2):\n",
" \"\"\"Compute the Levenshtein edit distance between two sequences\"\"\"\n",
" m = len(seq1)\n",
" n = len(seq2)\n",
"\n",
" D = levenshtein_matrix(seq1, seq2)\n",
" return D[m, n]\n",
"\n"
]
}
],
"source": [ "source": [
"from edit_distance import levenshtein_matrix, levenshtein\n", "from rapidfuzz.string_metric import levenshtein"
"\n",
"print(inspect.getsource(levenshtein_matrix))\n",
"print(inspect.getsource(levenshtein))"
] ]
}, },
{ {
@ -170,21 +128,23 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"def distance(s1, s2):\n", "@multimethod\n",
"def distance(s1: str, s2: str):\n",
" \"\"\"Compute the Levenshtein edit distance between two Unicode strings\n", " \"\"\"Compute the Levenshtein edit distance between two Unicode strings\n",
"\n", "\n",
" Note that this is different from levenshtein() as this function knows about Unicode normalization and grapheme\n", " Note that this is different from levenshtein() as this function knows about Unicode\n",
" clusters. This should be the correct way to compare two Unicode strings.\n", " normalization and grapheme clusters. This should be the correct way to compare two\n",
" Unicode strings.\n",
" \"\"\"\n", " \"\"\"\n",
" s1 = list(grapheme_clusters(unicodedata.normalize('NFC', s1)))\n", " seq1 = list(grapheme_clusters(unicodedata.normalize(\"NFC\", s1)))\n",
" s2 = list(grapheme_clusters(unicodedata.normalize('NFC', s2)))\n", " seq2 = list(grapheme_clusters(unicodedata.normalize(\"NFC\", s2)))\n",
" return levenshtein(s1, s2)\n", " return levenshtein(seq1, seq2)\n",
"\n" "\n"
] ]
} }
], ],
"source": [ "source": [
"from edit_distance import distance\n", "from qurator.dinglehopper.edit_distance import distance\n",
"print(inspect.getsource(distance))" "print(inspect.getsource(distance))"
] ]
}, },
@ -247,8 +207,7 @@
"source": [ "source": [
"# Edit operations\n", "# Edit operations\n",
"\n", "\n",
"python-Levenshtein supports backtracing, i.e. giving a sequence of edit options that transforms a word to another word:\n", "python-Levenshtein + RapidFuzz supports backtracing, i.e. giving a sequence of edit options that transforms a word to another word:"
"\n"
] ]
}, },
{ {
@ -257,32 +216,20 @@
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stdout", "data": {
"output_type": "stream", "text/plain": [
"text": [ "[('replace', 2, 2)]"
"[('insert', 5, 5), ('replace', 5, 6)]\n" ]
] },
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
} }
], ],
"source": [ "source": [
"import Levenshtein\n", "from rapidfuzz.string_metric import levenshtein_editops as editops\n",
"word1 = 'Schlyñ' # with LATIN SMALL LETTER N WITH TILDE\n", "\n",
"word2 = 'Schlym̃' # with LATIN SMALL LETTER M + COMBINING TILDE\n", "editops('Foo', 'Fon')"
"print(Levenshtein.editops(word1, word2))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that it does not work with grapheme clusters, but \"characters\", so it gives 2 operations."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Defining our own `editops()`. (This looks a bit wild due to our own tail recursion handling.)"
] ]
}, },
{ {
@ -294,47 +241,12 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"def seq_editops(seq1, seq2):\n", "[('insert', 4, 4)]\n"
" seq1 = list(seq1)\n",
" seq2 = list(seq2)\n",
" m = len(seq1)\n",
" n = len(seq2)\n",
" D = levenshtein_matrix(seq1, seq2)\n",
"\n",
" def _tail_backtrace(i, j, accumulator):\n",
" if i > 0 and D[i - 1, j] + 1 == D[i, j]:\n",
" return partial(_tail_backtrace, i - 1, j, [('delete', i-1, j)] + accumulator)\n",
" if j > 0 and D[i, j - 1] + 1 == D[i, j]:\n",
" return partial(_tail_backtrace, i, j - 1, [('insert', i, j-1)] + accumulator)\n",
" if i > 0 and j > 0 and D[i - 1, j - 1] + 1 == D[i, j]:\n",
" return partial(_tail_backtrace, i - 1, j - 1, [('replace', i-1, j-1)] + accumulator)\n",
" if i > 0 and j > 0 and D[i - 1, j - 1] == D[i, j]:\n",
" return partial(_tail_backtrace, i - 1, j - 1, accumulator) # NOP\n",
" return accumulator\n",
"\n",
" def backtrace(i, j):\n",
" result = partial(_tail_backtrace, i, j, [])\n",
" while isinstance(result, partial):\n",
" result = result()\n",
"\n",
" return result\n",
"\n",
" b = backtrace(m, n)\n",
" return b\n",
"\n",
"def editops(word1, word2):\n",
" # XXX Note that this returns indices to the _grapheme clusters_, not characters!\n",
" word1 = list(grapheme_clusters(unicodedata.normalize('NFC', word1)))\n",
" word2 = list(grapheme_clusters(unicodedata.normalize('NFC', word2)))\n",
" return seq_editops(word1, word2)\n",
"\n"
] ]
} }
], ],
"source": [ "source": [
"from edit_distance import seq_editops, editops\n", "print(editops('Käptn', 'Käpt\\'n'))"
"print(inspect.getsource(seq_editops))\n",
"print(inspect.getsource(editops))"
] ]
}, },
{ {
@ -343,18 +255,15 @@
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"data": { "name": "stdout",
"text/plain": [ "output_type": "stream",
"[('replace', 2, 2)]" "text": [
] "[('delete', 6, 6)]\n"
}, ]
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
} }
], ],
"source": [ "source": [
"editops('Foo', 'Fon')" "print(editops('Delete something', 'Deletesomething'))"
] ]
}, },
{ {
@ -366,66 +275,76 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"[('insert', 4, 4)]\n", "[('delete', 1, 1), ('replace', 13, 12), ('insert', 16, 15), ('delete', 23, 23)]\n"
"[('insert', 4, 4)]\n"
] ]
} }
], ],
"source": [ "source": [
"print(editops('Käptn', 'Käpt\\'n'))\n", "print(editops('A more difficult example', 'Amore difficült exampl'))"
"print(Levenshtein.editops('Käptn', 'Käpt\\'n'))"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "markdown",
"execution_count": 13,
"metadata": {}, "metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[('delete', 6, 6)]\n",
"[('delete', 6, 6)]\n"
]
}
],
"source": [ "source": [
"print(editops('Delete something', 'Deletesomething'))\n", "Let's try it with a difficult example that needs grapheme cluster handling:"
"print(Levenshtein.editops('Delete something', 'Deletesomething'))"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 14, "execution_count": 13,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stdout", "data": {
"output_type": "stream", "text/plain": [
"text": [ "[('insert', 5, 5), ('replace', 5, 6)]"
"[('delete', 1, 1), ('replace', 13, 12), ('insert', 17, 16), ('delete', 23, 23)]\n", ]
"[('delete', 1, 1), ('replace', 13, 12), ('insert', 16, 15), ('delete', 23, 23)]\n" },
] "execution_count": 13,
"metadata": {},
"output_type": "execute_result"
} }
], ],
"source": [ "source": [
"print(editops('A more difficult example', 'Amore difficült exampl'))\n", "word1 = 'Schlyñ' # with LATIN SMALL LETTER N WITH TILDE\n",
"print(Levenshtein.editops('A more difficult example', 'Amore difficült exampl'))" "word2 = 'Schlym̃' # with LATIN SMALL LETTER M + COMBINING TILDE\n",
"\n",
"editops(word1, word2)"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"XXX Note that our implementation returns different positions here for the 'insert'. " "That doesn't look right, let's redefine it with grapheme cluster support:"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "code",
"execution_count": 14,
"metadata": {}, "metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"def editops(word1, word2):\n",
" \"\"\"\n",
" Return sequence of edit operations transforming one string to another.\n",
"\n",
" Note that this returns indices to the _grapheme clusters_, not characters!\n",
" \"\"\"\n",
" word1 = list(grapheme_clusters(unicodedata.normalize(\"NFC\", word1)))\n",
" word2 = list(grapheme_clusters(unicodedata.normalize(\"NFC\", word2)))\n",
" return levenshtein_editops(word1, word2)\n",
"\n"
]
}
],
"source": [ "source": [
"Let's try it with a difficult example that needs grapheme cluster handling:" "from qurator.dinglehopper.edit_distance import editops\n",
"print(inspect.getsource(editops))"
] ]
}, },
{ {
@ -455,7 +374,9 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"🎉" "🎉\n",
"\n",
"Here, a problem is that the positions are grapheme cluster positions, not Python character indexes!"
] ]
}, },
{ {
@ -489,22 +410,20 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"def character_error_rate(reference, compared):\n", "def character_error_rate(reference, compared) -> float:\n",
" d = distance(reference, compared)\n", " \"\"\"\n",
" if d == 0:\n", " Compute character error rate.\n",
" return 0\n",
"\n",
" n = len(list(grapheme_clusters(unicodedata.normalize('NFC', reference))))\n",
" if n == 0:\n",
" return float('inf')\n",
"\n", "\n",
" return d/n\n", " :return: character error rate\n",
" \"\"\"\n",
" cer, _ = character_error_rate_n(reference, compared)\n",
" return cer\n",
"\n" "\n"
] ]
} }
], ],
"source": [ "source": [
"from character_error_rate import character_error_rate\n", "from qurator.dinglehopper.character_error_rate import character_error_rate\n",
"print(inspect.getsource(character_error_rate))" "print(inspect.getsource(character_error_rate))"
] ]
}, },
@ -732,16 +651,20 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"def words(s):\n", "@multimethod\n",
"def words(s: str):\n",
" \"\"\"Extract words from a string\"\"\"\n",
"\n",
" # Patch uniseg.wordbreak.word_break to deal with our private use characters. See also\n", " # Patch uniseg.wordbreak.word_break to deal with our private use characters. See also\n",
" # https://www.unicode.org/Public/UCD/latest/ucd/auxiliary/WordBreakProperty.txt\n", " # https://www.unicode.org/Public/UCD/latest/ucd/auxiliary/WordBreakProperty.txt\n",
" old_word_break = uniseg.wordbreak.word_break\n", " old_word_break = uniseg.wordbreak.word_break\n",
"\n", "\n",
" def new_word_break(c, index=0):\n", " def new_word_break(c, index=0):\n",
" if 0xE000 <= ord(c) <= 0xF8FF: # Private Use Area\n", " if 0xE000 <= ord(c) <= 0xF8FF: # Private Use Area\n",
" return 'ALetter'\n", " return \"ALetter\"\n",
" else:\n", " else:\n",
" return old_word_break(c, index)\n", " return old_word_break(c, index)\n",
"\n",
" uniseg.wordbreak.word_break = new_word_break\n", " uniseg.wordbreak.word_break = new_word_break\n",
"\n", "\n",
" # Check if c is an unwanted character, i.e. whitespace, punctuation, or similar\n", " # Check if c is an unwanted character, i.e. whitespace, punctuation, or similar\n",
@ -749,8 +672,8 @@
"\n", "\n",
" # See https://www.fileformat.info/info/unicode/category/index.htm\n", " # See https://www.fileformat.info/info/unicode/category/index.htm\n",
" # and https://unicodebook.readthedocs.io/unicode.html#categories\n", " # and https://unicodebook.readthedocs.io/unicode.html#categories\n",
" unwanted_categories = 'O', 'M', 'P', 'Z', 'S'\n", " unwanted_categories = \"O\", \"M\", \"P\", \"Z\", \"S\"\n",
" unwanted_subcategories = 'Cc', 'Cf'\n", " unwanted_subcategories = \"Cc\", \"Cf\"\n",
"\n", "\n",
" subcat = unicodedata.category(c)\n", " subcat = unicodedata.category(c)\n",
" cat = subcat[0]\n", " cat = subcat[0]\n",
@ -778,7 +701,7 @@
} }
], ],
"source": [ "source": [
"from word_error_rate import words\n", "from qurator.dinglehopper.word_error_rate import words\n",
"print(inspect.getsource(words))\n", "print(inspect.getsource(words))\n",
"\n", "\n",
"list(words(example_text))" "list(words(example_text))"
@ -905,29 +828,15 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"def word_error_rate(reference, compared):\n", "def word_error_rate(reference, compared) -> float:\n",
" if isinstance(reference, str):\n", " wer, _ = word_error_rate_n(reference, compared)\n",
" reference_seq = list(words_normalized(reference))\n", " return wer\n",
" compared_seq = list(words_normalized(compared))\n",
" else:\n",
" reference_seq = list(reference)\n",
" compared_seq = list(compared)\n",
"\n",
" d = levenshtein(reference_seq, compared_seq)\n",
" if d == 0:\n",
" return 0\n",
"\n",
" n = len(reference_seq)\n",
" if n == 0:\n",
" return float('inf')\n",
"\n",
" return d / n\n",
"\n" "\n"
] ]
} }
], ],
"source": [ "source": [
"from word_error_rate import word_error_rate\n", "from qurator.dinglehopper.word_error_rate import word_error_rate\n",
"print(inspect.getsource(word_error_rate))" "print(inspect.getsource(word_error_rate))"
] ]
}, },
@ -1002,9 +911,9 @@
"metadata": { "metadata": {
"hide_input": false, "hide_input": false,
"kernelspec": { "kernelspec": {
"display_name": "Python 3", "display_name": "dinglehopper-github",
"language": "python", "language": "python",
"name": "python3" "name": "dinglehopper-github"
}, },
"language_info": { "language_info": {
"codemirror_mode": { "codemirror_mode": {
@ -1016,7 +925,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.7.3" "version": "3.7.12"
}, },
"toc": { "toc": {
"base_numbering": 1, "base_numbering": 1,

Loading…
Cancel
Save