LCSとdiff

Luaで書いてみました。

定義

a, bはデータの列で長さはそれぞれn, mとする。
a[1], a[2], ..., a[n]
b[1], b[2], ..., b[m]

LCS (Longest Common Subsequence) lcs(a, b, n, m) は
(xs[1], ys[1]), (xs[2], ys[2]), ..., (xs[l], ys[l])
と表され、以下を満たす。

性質

a[n] == b[m] ならば (xs[l] == n and ys[l] == m) であるようなLCSが存在する。

この場合は、lcs(a, b, n-1, m-1) を求めて末尾に (n, m) を付け加える。

a[n] ~= b[m]ならばすべてのLCSで (xs[l] ~= n or ys[l] ~= m) である。
この場合は、lcs(a, b, n, m-1), lcs(a, b, n-1, m) のうち長いほうが lcs(a, b, n, m) の値である。

アルゴリズム・実装

素朴に再帰すると遅すぎて使えません。

メモ化するという手もありますが、やや面倒そうなのでやめます。

LCSといえば動的計画法です。
動的計画法は広い意味ではメモ化のようなものですが、
「副作用のない関数の呼び出しを高速化する」という 狭い意味でのメモ化とは異なります。

O(MN)のアルゴリズム

先に lcs(a, b, x, y) の長さを n×m の配列に格納した後で xs, ys を求めます。

lcs = function(a, b, n, m)
  local get, set
  do
    local t = {}
    get = function(x, y)
      local val = t[(n+1) * y + x]
      if val == nil then return 0
      else return val end
    end
    set = function(x, y, z)
      t[(n+1) * y + x] = z
    end
  end
  for y = 1, m do
    for x = 1, n do
      local z
      if a[x] == b[y] then z = get(x-1, y-1)+1
      else z = math.max(get(x-1, y), get(x, y-1)) end
      set(x, y, z)
    end
  end
  local l = get(n, m)
  local xs, ys = {}, {}
  local x, y = n, m
  while l > 0 do
    if a[x] == b[y] then
      xs[l], ys[l] = x, y
      l = l-1; x = x-1; y = y-1
    elseif get(x-1, y) < get(x, y-1) then
      y = y-1
    else
      x = x-1
    end
  end
  return xs, ys
end

O((M+N)D)のアルゴリズム

俗に O(ND) と呼ばれているようです。
D は、aを編集してbをつくるとき要素を挿入または削除する回数です。
a と b が似ていると小さく、異なると大きくなる「編集距離」です。

m×n の配列を埋めるという方針は同じですが、
d = x + y - 2*l (l はLCSの長さ) の値が小さい箇所 (x,y) を優先的に埋めていきます。
d = 0, 1, 2, ... とループを回し d = D の時に打ち切られます。
すると、必要のない箇所は埋めずに済んでしまうので、計算量が減ります。

例えば下の表のように、ある d まで表を埋めて、
ln-2, m-3 と ln-1, m の値が分かっているとします。
(lx, y = (x+y-d)/2 の形。)
この状況で (??) の値を求めなくても (*) を計算できます。
なぜなら、
ln, m-1 ≤ ln-1, m-2+1 ≤ ln-2, m-3+2 = (n+m-d-5)/2 + 2 = (n+m-d-1)/2 = ln-1, m.
∴ ln, m-1 ≤ ln-1, m.

m-3m-2m-1m
n-2ln-2, m-3 = (n+m-d-5)/2------
n-1--(??)--ln-1, m = (n+m-d-1)/2
n----(??)(*)
lcs = function(a, b, n, m)
  local v = {}
  local path, set_path
  do
    local t = {}
    path = function(x, y) return t[(m+n+1) * y + x] end
    set_path = function(x, y, z) t[(m+n+1) * y + x] = z end
  end
  local make_lcs = function(l)
    local xs, ys = {}, {}
    local x, y = n, m
    while x > 0 and y > 0 do
      if path(x, y) == "vertical" then
        y = y-1
      elseif path(x, y) == "horizontal" then
        x = x-1
      else
        assert(path(x, y) == "diagonal")
        xs[l], ys[l] = x, y
        l = l-1; x = x-1; y = y-1;
      end
    end
    return xs, ys
  end

  v[1] = 0

  for d = 0, m + n do
    for k = -d, d, 2 do
      local x, y
      if k == -d or (k ~= d and v[k-1] < v[k+1]) then
        x = v[k+1]; y = x - k
        set_path(x, y, "vertical")
      else
        x = v[k-1] + 1; y = x - k
        set_path(x, y, "horizontal")
      end
      while x < n and y < m and a[x+1] == b[y+1] do
        x, y = x+1, y+1
        set_path(x, y, "diagonal")
      end
      v[k] = x
      if x == n and y == m then return make_lcs((m + n - d) / 2) end
    end
  end
end

diff

Unix系のOSでお馴染みのdiffですが、 Luaで書いておけばUnix以外の環境でも簡単に動かせると思います。

diff = function(old, new)
  local a, b = {}, {}
  for line in io.lines(old) do table.insert(a, line) end
  for line in io.lines(new) do table.insert(b, line) end
  local xs, ys = lcs(a, b, #a, #b)
  table.insert(xs, #a+1)
  table.insert(ys, #b+1)
  local i, x, y = 1, 1, 1
  while true do
    if x == xs[i] and y == ys[i] then
      if i == #xs then return end
      i = i+1; x = x+1; y = y+1
    elseif x == xs[i] then
      assert(y < ys[i])
      io.write(string.format("%da", x-1))
      if y == ys[i]-1 then
        io.write(string.format("%d\n", y))
      else
        io.write(string.format("%d,%d\n", y, ys[i] - 1))
      end
      repeat
        io.write(string.format("> %s\n", b[y]))
        y = y+1
      until y == ys[i]
    elseif y == ys[i] then
      assert(x < xs[i])
      if x == xs[i]-1 then
        io.write(string.format("%dd", x))
      else
        io.write(string.format("%d,%dd", x, xs[i]-1))
      end
      io.write(string.format("%d\n", y-1))
      repeat
        io.write(string.format("< %s\n", a[x]))
        x = x+1
      until x == xs[i]
    else
      assert(x < xs[i] and y < ys[i])
      if x == xs[i]-1 then
        io.write(string.format("%dc", x))
      else
        io.write(string.format("%d,%dc", x, xs[i]-1))
      end
      if y == ys[i]-1 then
        io.write(string.format("%d\n", y))
      else
        io.write(string.format("%d,%d\n", y, ys[i]-1))
      end
      repeat
        io.write(string.format("< %s\n", a[x]))
        x = x+1
      until x == xs[i]
      io.write("---\n")
      repeat
        io.write(string.format("> %s\n", b[y]))
        y = y+1
      until y == ys[i]
    end
  end
end

参考

inserted by FC2 system