Skip to content

Commit

Permalink
resolve conflict with threshold and missing value (#85)
Browse files Browse the repository at this point in the history
Closes #70
  • Loading branch information
J535D165 committed Mar 14, 2019
1 parent 4b8c167 commit 702899e
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 4 deletions.
9 changes: 5 additions & 4 deletions recordlinkage/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,13 +151,14 @@ def _compute_vectorized(self, s_left, s_right):
self.method))

c = str_sim_alg(s_left, s_right)
c = _fillna(c, self.missing_value)

if self.threshold is not None:
return (c >= self.threshold).astype(numpy.float64)
else:
return c
c = c.where((c < self.threshold) | (pandas.isnull(c)), other=1.0)
c = c.where((c >= self.threshold) | (pandas.isnull(c)), other=0.0)

c = _fillna(c, self.missing_value)

return c

class Numeric(BaseCompareFeature):
"""Compute the (partial) similarity between numeric values.
Expand Down
42 changes: 42 additions & 0 deletions tests/test_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -1183,6 +1183,48 @@ def test_fuzzy(self):
assert (result[result.notnull()] >= 0).all(1).all(0)
assert (result[result.notnull()] <= 1).all(1).all(0)

def test_threshold(self):

A = DataFrame({'col': [u"gretzky", u"gretzky99", u"gretzky", u"gretzky"]})
B = DataFrame({'col': [u"gretzky", u"gretzky", nan, u"wayne"]})
ix = MultiIndex.from_arrays([A.index.values, B.index.values])

comp = recordlinkage.Compare()
comp.string(
'col',
'col',
method="levenshtein",
threshold=0.5,
missing_value=2.0,
label="x_col1"
)
comp.string(
'col',
'col',
method="levenshtein",
threshold=1.0,
missing_value=0.5,
label="x_col2"
)
comp.string(
'col',
'col',
method="levenshtein",
threshold=0.0,
missing_value=nan,
label="x_col3"
)
result = comp.compute(ix, A, B)

expected = Series([1.0, 1.0, 2.0, 0.0], index=ix, name="x_col1")
pdt.assert_series_equal(result["x_col1"], expected)

expected = Series([1.0, 0.0, 0.5, 0.0], index=ix, name="x_col2")
pdt.assert_series_equal(result["x_col2"], expected)

expected = Series([1.0, 1.0, nan, 1.0], index=ix, name="x_col3")
pdt.assert_series_equal(result["x_col3"], expected)

@pytest.mark.parametrize("alg", STRING_SIM_ALGORITHMS)
def test_incorrect_input(self, alg):

Expand Down

0 comments on commit 702899e

Please sign in to comment.