{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE MultiWayIf #-}
module Data.Text.Metrics
(
levenshtein
, levenshteinNorm
, damerauLevenshtein
, damerauLevenshteinNorm
, overlap
, jaccard
, hamming
, jaro
, jaroWinkler )
where
import Control.Monad
import Control.Monad.ST
import Data.Map.Strict (Map)
import Data.Ratio
import Data.Text
import GHC.Exts (inline)
import qualified Data.Map.Strict as M
import qualified Data.Text as T
import qualified Data.Text.Unsafe as TU
import qualified Data.Vector.Unboxed.Mutable as VUM
#if !MIN_VERSION_base(4,8,0)
import Control.Applicative
#endif
-- | Levenshtein distance between two 'Text' values: the minimal number of
-- single-character insertions, deletions and substitutions needed to turn
-- one into the other.
levenshtein :: Text -> Text -> Int
levenshtein a b =
  let (distance, _longest) = levenshtein_ a b
   in distance
-- | Normalized Levenshtein similarity: @1 - distance / length of the longer
-- input@, or exactly 1 when the distance is 0 (this includes two empty
-- inputs).
levenshteinNorm :: Text -> Text -> Ratio Int
levenshteinNorm a b = norm levenshtein_ a b
-- | Worker for the Levenshtein functions. Returns the Levenshtein distance
-- together with the length (in characters) of the longer input, which the
-- normalized variant divides by.
--
-- Uses the classic dynamic-programming recurrence with two rolling rows of
-- @lenb + 1@ cells stored back to back in one mutable vector. @v0@ / @v1@
-- are the offsets of the previous and the current row and are swapped after
-- every character of @a@.
levenshtein_ :: Text -> Text -> (Int, Int)
levenshtein_ a b
  | T.null a = (lenb, lenm)
  | T.null b = (lena, lenm)
  | otherwise = runST $ do
      let v_len = lenb + 1
      v <- VUM.unsafeNew (v_len * 2)
      -- Row 0: distance from the empty prefix of a to each prefix of b.
      let gov !i =
            when (i < v_len) $ do
              VUM.unsafeWrite v i i
              gov (i + 1)
          -- i is the character index into a, na its Word16 offset. The
          -- character is decoded only after the bounds check, so the loop
          -- never reads past the end of a.
          goi !i !na !v0 !v1 =
            when (i < lena) $ do
              let !(TU.Iter ai da) = TU.iter a na
                  goj !j !nb =
                    when (j < lenb) $ do
                      let !(TU.Iter bj db) = TU.iter b nb
                          cost = if ai == bj then 0 else 1
                      x <- (+ 1) <$> VUM.unsafeRead v (v1 + j)
                      y <- (+ 1) <$> VUM.unsafeRead v (v0 + j + 1)
                      z <- (+ cost) <$> VUM.unsafeRead v (v0 + j)
                      VUM.unsafeWrite v (v1 + j + 1) (min x (min y z))
                      goj (j + 1) (nb + db)
              VUM.unsafeWrite v v1 (i + 1)
              goj 0 0
              goi (i + 1) (na + da) v1 v0
      gov 0
      goi 0 0 0 v_len
      -- The last row written lives in the second half when lena is odd
      -- and in the first half when lena is even.
      ld <- VUM.unsafeRead v (lenb + if even lena then 0 else v_len)
      return (ld, lenm)
  where
    lena = T.length a
    lenb = T.length b
    lenm = max lena lenb
{-# INLINE levenshtein_ #-}
-- | Damerau-Levenshtein distance between two 'Text' values. Like
-- 'levenshtein', but swapping two adjacent characters also counts as a
-- single edit. This is the restricted variant (optimal string alignment)
-- computed by 'damerauLevenshtein_'.
damerauLevenshtein :: Text -> Text -> Int
damerauLevenshtein a b =
  let (distance, _longest) = damerauLevenshtein_ a b
   in distance
-- | Normalized Damerau-Levenshtein similarity: @1 - distance / length of the
-- longer input@, or exactly 1 when the distance is 0 (this includes two
-- empty inputs).
damerauLevenshteinNorm :: Text -> Text -> Ratio Int
damerauLevenshteinNorm a b = norm damerauLevenshtein_ a b
-- | Worker for the Damerau-Levenshtein functions. Returns the distance
-- together with the length (in characters) of the longer input.
--
-- Implements the optimal-string-alignment recurrence: the Levenshtein
-- recurrence plus an adjacent-transposition case that looks two rows back.
-- Three rolling rows of @lenb + 1@ cells are stored back to back. @v2@,
-- @v0@ and @v1@ are the offsets of rows i-1, i and i+1 and are rotated after
-- every character of @a@. @ai_1@ / @bj_1@ hold the previous characters of a
-- and b. Their initial placeholder value is never consulted, because the
-- transposition case requires @i > 0 && j > 0@.
damerauLevenshtein_ :: Text -> Text -> (Int, Int)
damerauLevenshtein_ a b
  | T.null a = (lenb, lenm)
  | T.null b = (lena, lenm)
  | otherwise = runST $ do
      let v_len = lenb + 1
      v <- VUM.unsafeNew (v_len * 3)
      -- Row 0: distance from the empty prefix of a to each prefix of b.
      let gov !i =
            when (i < v_len) $ do
              VUM.unsafeWrite v i i
              gov (i + 1)
          -- The character of a is decoded only after the bounds check, so
          -- the loop never reads past the end of a.
          goi !i !na !ai_1 !v0 !v1 !v2 =
            when (i < lena) $ do
              let !(TU.Iter ai da) = TU.iter a na
                  goj !j !nb !bj_1 =
                    when (j < lenb) $ do
                      let !(TU.Iter bj db) = TU.iter b nb
                          cost = if ai == bj then 0 else 1
                      x <- (+ 1) <$> VUM.unsafeRead v (v1 + j)
                      y <- (+ 1) <$> VUM.unsafeRead v (v0 + j + 1)
                      z <- (+ cost) <$> VUM.unsafeRead v (v0 + j)
                      let g = min x (min y z)
                      -- Only touch row i-1 when a transposition is possible.
                      -- Reading it unconditionally would access index -1
                      -- (when v2 == 0, j == 0) or an uninitialized row
                      -- (when i == 0).
                      best <-
                        if i > 0 && j > 0 && ai == bj_1 && ai_1 == bj
                          then min g . (+ cost) <$> VUM.unsafeRead v (v2 + j - 1)
                          else return g
                      VUM.unsafeWrite v (v1 + j + 1) best
                      goj (j + 1) (nb + db) bj
              VUM.unsafeWrite v v1 (i + 1)
              goj 0 0 'a'
              goi (i + 1) (na + da) ai v1 v2 v0
      gov 0
      goi 0 0 'a' 0 v_len (v_len * 2)
      -- After lena rotations, the last row written starts at
      -- (lena `mod` 3) * v_len.
      ld <- VUM.unsafeRead v (lenb + (lena `mod` 3) * v_len)
      return (ld, lenm)
  where
    lena = T.length a
    lenb = T.length b
    lenm = max lena lenb
{-# INLINE damerauLevenshtein_ #-}
-- | Overlap coefficient of the character multisets of two 'Text' values: the
-- size of their multiset intersection divided by the length of the shorter
-- input. Returns 1 when either input is empty.
overlap :: Text -> Text -> Ratio Int
overlap a b
  | shorter == 0 = 1 % 1
  | otherwise = common % shorter
  where
    shorter = T.length a `min` T.length b
    common = intersectionSize (mkTextMap a) (mkTextMap b)
-- | Jaccard similarity of the character multisets of two 'Text' values: the
-- size of their multiset intersection divided by the size of their multiset
-- union. Returns 1 when the union is empty, i.e. both inputs are empty.
jaccard :: Text -> Text -> Ratio Int
jaccard a b
  | total == 0 = 1 % 1
  | otherwise = intersectionSize countsA countsB % total
  where
    countsA = mkTextMap a
    countsB = mkTextMap b
    total = unionSize countsA countsB
-- | Count how many times each character occurs in a 'Text', i.e. build the
-- character multiset of the input.
mkTextMap :: Text -> Map Char Int
mkTextMap txt = M.fromListWith (+) [(c, 1) | c <- T.unpack txt]
{-# INLINE mkTextMap #-}
-- | Size of the intersection of two character multisets: for every
-- character present in both, add the smaller of its two counts.
intersectionSize :: Map Char Int -> Map Char Int -> Int
intersectionSize a b = M.foldr' (+) 0 shared
  where
    shared = M.intersectionWith min a b
{-# INLINE intersectionSize #-}
-- | Size of the union of two character multisets: for every character
-- present in either, add the larger of its counts.
unionSize :: Map Char Int -> Map Char Int -> Int
unionSize a b = M.foldr' (+) 0 merged
  where
    merged = M.unionWith max a b
{-# INLINE unionSize #-}
-- | Hamming distance between two 'Text' values: the number of positions at
-- which the corresponding characters differ. Returns 'Nothing' when the
-- inputs have different lengths (measured in characters).
hamming :: Text -> Text -> Maybe Int
hamming a b
  | T.length a /= T.length b = Nothing
  | otherwise = Just (go 0 0 0)
  where
    -- na / nb are Word16 offsets into a / b. Both inputs have the same
    -- number of characters, so they are exhausted together. The end check
    -- comes before decoding, so no character is read past the end.
    go !na !nb !r
      | na == len = r
      | otherwise =
          let !(TU.Iter cha da) = TU.iter a na
              !(TU.Iter chb db) = TU.iter b nb
              !r' = if cha /= chb then r + 1 else r
           in go (na + da) (nb + db) r'
    len = TU.lengthWord16 a
-- | Jaro similarity of two 'Text' values, from 0 (no similarity) to 1
-- (exact match). Returns 0 when either input is empty.
--
-- Two characters match when they are equal and their positions differ by at
-- most @max 0 (max lena lenb \`quot\` 2 - 1)@. Characters of @a@ are matched
-- greedily from left to right, each against the first unused equal
-- character of @b@ inside the window. The number of transpositions is half
-- (rounded down) the number of positions at which the matched characters of
-- @a@, read in order, differ from the matched characters of @b@, read in
-- order. Because both sequences are compared position by position, the
-- transposition count is symmetric in the two inputs.
jaro :: Text -> Text -> Ratio Int
jaro a b =
  if T.null a || T.null b
    then 0 % 1
    else runST $ do
      let lena = T.length a
          lenb = T.length b
          -- The match window depends only on the longer input.
          d = max 0 (max lena lenb `quot` 2 - 1)
      -- used[j] becomes True once the j-th character of b is matched.
      used <- VUM.replicate lenb False
      -- Matched characters of a, in the order they occur in a.
      am <- VUM.unsafeNew lena
      let -- First pass. i is the character index into a, na its Word16
          -- offset. fromb is the Word16 offset in b of character
          -- max 0 (i - d), the start of the window. m is the number of
          -- matches found so far.
          goi !i !na !fromb !m
            | i >= lena = return m
            | otherwise = do
                let !(TU.Iter ai da) = TU.iter a na
                    from = max 0 (i - d)
                    to = min (i + d + 1) lenb
                    goj !j !nb
                      | j >= to = return False
                      | otherwise = do
                          let !(TU.Iter bj db) = TU.iter b nb
                          u <- VUM.unsafeRead used j
                          if not u && ai == bj
                            then do
                              VUM.unsafeWrite used j True
                              return True
                            else goj (j + 1) (nb + db)
                    -- The window start moves right once i >= d. It never
                    -- moves past the end of b, which would decode memory
                    -- outside the Text.
                    fromb' =
                      if i >= d && from < lenb
                        then fromb + TU.iter_ b fromb
                        else fromb
                found <- goj from fromb
                if found
                  then do
                    VUM.unsafeWrite am m ai
                    goi (i + 1) (na + da) fromb' (m + 1)
                  else goi (i + 1) (na + da) fromb' m
          -- Second pass: walk the matched characters of b in order. Count
          -- the positions where they disagree with the matched characters
          -- of a (k indexes into am).
          gok !j !nb !k !mism
            | j >= lenb = return mism
            | otherwise = do
                let !(TU.Iter bj db) = TU.iter b nb
                u <- VUM.unsafeRead used j
                if u
                  then do
                    ak <- VUM.unsafeRead am k
                    gok (j + 1) (nb + db) (k + 1)
                      (if ak == bj then mism else mism + 1)
                  else gok (j + 1) (nb + db) k mism
      m <- goi 0 0 0 0
      if m == 0
        then return (0 % 1)
        else do
          mism <- gok 0 0 0 0
          let t = mism `quot` 2
          return $
            ((m % lena) +
             (m % lenb) +
             ((m - t) % m)) / 3
-- | Jaro-Winkler similarity of two 'Text' values, from 0 (no similarity)
-- to 1 (exact match). This is the 'jaro' similarity @dj@ boosted by the
-- common prefix: @dj + l / 10 * (1 - dj)@, where @l@ is the length of the
-- common prefix, capped at 4 characters.
jaroWinkler :: Text -> Text -> Ratio Int
jaroWinkler a b = dj + (1 % 10) * l * (1 - dj)
  where
    dj = inline (jaro a b)
    -- Winkler's definition rewards at most four prefix characters. Without
    -- the cap, a prefix longer than 10 characters makes l / 10 exceed 1,
    -- and the result exceeds 1.
    l = fromIntegral (min 4 (commonPrefix a b))
-- | Length, in characters, of the longest common prefix of two 'Text'
-- values.
commonPrefix :: Text -> Text -> Int
commonPrefix a b = go 0 0 0
  where
    -- na / nb are Word16 offsets into a / b. The end-of-input check comes
    -- before decoding, so no character is read past the end of either text.
    go !na !nb !r
      | na == lena || nb == lenb = r
      | otherwise =
          let !(TU.Iter cha da) = TU.iter a na
              !(TU.Iter chb db) = TU.iter b nb
           in if cha == chb
                then go (na + da) (nb + db) (r + 1)
                else r
    lena = TU.lengthWord16 a
    lenb = TU.lengthWord16 b
{-# INLINE commonPrefix #-}
-- | Turn a distance worker into a normalized similarity. The worker returns
-- @(distance, longest length)@. The result is @1 - distance / longest@, or
-- exactly 1 when the distance is 0. Checking the distance first avoids
-- dividing by a zero length when both inputs are empty.
norm :: (Text -> Text -> (Int, Int)) -> Text -> Text -> Ratio Int
norm f a b
  | dist == 0 = 1 % 1
  | otherwise = 1 % 1 - dist % longest
  where
    (dist, longest) = f a b
{-# INLINE norm #-}