Skip to content

Commit

Permalink
(Hopefully) fixed data type problems with find_objects/label on 32 bi…
Browse files Browse the repository at this point in the history
…t machines.
  • Loading branch information
Tom committed Jul 15, 2012
1 parent ebcd7d9 commit faec00d
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 28 deletions.
38 changes: 25 additions & 13 deletions ocrolib/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,28 +50,40 @@ def pyargsort(seq,cmp=cmp,key=lambda x:x):
return sorted(range(len(seq)),key=lambda x:key(seq.__getitem__(x)),cmp=cmp)

def renumber_by_xcenter(seg):
objects = [(slice(0,0),slice(0,0))]+measurements.find_objects(seg)
objects = [(slice(0,0),slice(0,0))]+find_objects(seg)
def xc(o): return mean((o[1].start,o[1].stop))
xs = array([xc(o) for o in objects])
order = argsort(xs)
segmap = zeros(amax(seg)+1,'i')
for i,j in enumerate(order): segmap[j] = i
return segmap[seg]

def flexible_find_objects(image):
"""Like measurements.find_objects, but tries to
be a bit more flexible about the datatypes it accepts."""
# first try the default type
try: return measurements.find_objects(image)
def label(image,**kw):
"""measurements.label fails to document what types it accepts,
and it fails randomly with different types on different
platforms. This tries to work around that."""
try: return measurements.label(image,**kw)
except: pass
types = ["int32","int64","int16"]
types = ["int32","uint32","int64","unit64","int16","uint16"]
for t in types:
# try with type conversions
try: return measurements.find_objects(array(image,dtype=t))
try: return measurements.label(array(image,dtype=t),**kw)
except: pass
# let it raise the same exception as before
return measurements.find_objects(image)
return measurements.label(image,**kw)

def find_objects(image,**kw):
"""measurements.find_objects fails to document what types it accepts,
and it fails randomly with different types on different
platforms. This tries to work around that."""
try: return measurements.find_objects(image,**kw)
except: pass
types = ["int32","uint32","int64","unit64","int16","uint16"]
for t in types:
try: return measurements.find_objects(array(image,dtype=t),**kw)
except: pass
# let it raise the same exception as before
return measurements.find_objects(image,**kw)

def rgb2int(image):
"""Converts a rank 3 array with RGB values stored in the
last axis into a rank 2 array containing 32 bit RGB values."""
Expand Down Expand Up @@ -109,7 +121,7 @@ def setImageMasked(self,image,mask=None,lo=None,hi=None):
labels,correspondence = renumber_labels_ordered(labels,correspondence=1)
self.labels = labels
self.correspondence = correspondence
self.objects = [None]+flexible_find_objects(labels)
self.objects = [None]+find_objects(labels)
def setPageColumns(self,image):
"""Set the image to be iterated over. This should be an RGB image,
ndim==3, dtype=='B'. This iterates over the columns."""
Expand Down Expand Up @@ -811,7 +823,7 @@ def estimate_xheight(line,scale=1.0,debug=0):
return bottom-top,bottom

def keep_marked(image,markers):
labels,_ = measurements.label(image)
labels,_ = label(image)
imshow(sin(17.1*labels),cmap=cm.jet)
marked = unique(labels*(markers!=0))
print marked
Expand Down Expand Up @@ -840,7 +852,7 @@ def latin_filter(line,scale=1.0,r=1.2,debug=0):

def remove_noise(line,minsize=8):
bin = (line>0.5*amax(line))
labels,n = measurements.label(bin)
labels,n = label(bin)
sums = measurements.sum(bin,labels,range(n+1))
sums = sums[labels]
good = minimum(bin,1-(sums>0)*(sums<minsize))
Expand Down
6 changes: 3 additions & 3 deletions ocrolib/docproc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from scipy import stats
from scipy.ndimage import measurements
from pylab import *

from common import *
import common

def avg(*args):
return mean(args)
Expand All @@ -19,7 +19,7 @@ def seg_boxes(seg,math=0):
coordinates are used (however, the order of the values in the
tuple doesn't change)."""
seg = array(seg,'uint32')
slices = measurements.find_objects(seg)
slices = common.find_objects(seg)
h = seg.shape[0]
result = []
for i in range(len(slices)):
Expand Down Expand Up @@ -117,7 +117,7 @@ def bbox(image):
"""Compute the bounding box for the pixels in the image."""
assert len(image.shape)==2,"wrong shape: "+str(image.shape)
image = array(image!=0,'uint32')
cs = scipy.ndimage.measurements.find_objects(image)
cs = common.find_objects(image)
if len(cs)<1: return None
c = cs[0]
return (c[0].start,c[1].start,c[0].stop,c[1].stop)
Expand Down
4 changes: 2 additions & 2 deletions ocrolib/grouper.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def setSegmentation(self,segmentation,cseg=0,preferred=None):
# print sorted(correspondences)
self.pre2seg = correspondences
# compute the bounding boxes in order
boxes = [None]+measurements.find_objects(segmentation)
boxes = [None]+common.find_objects(segmentation)
n = len(boxes)
# now consider groups of boxes
groups = []
Expand Down Expand Up @@ -107,7 +107,7 @@ def setCSegmentation(self,segmentation):
the groups corresponding to each labeled object. Objects should be labeled
consecutively."""
# compute the bounding boxes in order
boxes = [None] + measurements.find_objects(segmentation)
boxes = [None] + common.find_objects(segmentation)
n = len(boxes)
# now consider groups of boxes
groups = []
Expand Down
5 changes: 3 additions & 2 deletions ocrolib/lineseg.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pylab import *
from scipy.ndimage import filters,morphology,measurements
import psegutils
import common



Expand Down Expand Up @@ -85,7 +86,7 @@ def ccslineseg(image):
center = filters.maximum_filter(center,(3,3))
center = psegutils.keep_marked(image>0.5,center)
center = filters.maximum_filter(center,(2,2))
center,_ = measurements.label(center)
center,_ = common.label(center)
center = psegutils.spread_labels(center)
center *= image
return center
Expand Down Expand Up @@ -136,7 +137,7 @@ def charseg(self,line):
tracks = dplineseg2(line,imweight=self.imweight,bweight=self.bweight,
diagweight=self.diagweight,debug=self.debug,r=self.r)
tracks = array(tracks<0.5*amax(tracks),'i')
tracks,_ = measurements.label(tracks)
tracks,_ = common.label(tracks)
self.tracks = tracks
rsegs = psegutils.spread_labels(tracks)
rsegs = rsegs*(line>0.5*amax(line))
Expand Down
17 changes: 9 additions & 8 deletions ocrolib/psegutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from scipy.ndimage import filters,interpolation,morphology,measurements
from scipy import stats
from scipy.misc import imsave
import common

class record:
def __init__(self,**kw): self.__dict__.update(kw)
Expand Down Expand Up @@ -80,7 +81,7 @@ def spread_labels(labels,maxdist=9999999):
return spread

def keep_marked(image,markers):
labels,_ = measurements.label(image)
labels,_ = common.label(image)
marked = unique(labels*(markers!=0))
kept = in1d(labels.ravel(),marked)
return (image!=0)*kept.reshape(*labels.shape)
Expand All @@ -100,7 +101,7 @@ def correspondences(labels1,labels2):

def propagate_labels_simple(regions,labels):
"""Spread the labels to the corresponding regions."""
rlabels,_ = measurements.label(regions)
rlabels,_ = common.label(regions)
cors = correspondences(rlabels,labels)
outputs = zeros(amax(rlabels)+1,'i')
for o,i in cors.T: outputs[o] = i
Expand All @@ -109,7 +110,7 @@ def propagate_labels_simple(regions,labels):

def propagate_labels(regions,labels,conflict=0):
"""Spread the labels to the corresponding regions."""
rlabels,_ = measurements.label(regions)
rlabels,_ = common.label(regions)
cors = correspondences(rlabels,labels)
outputs = zeros(amax(rlabels)+1,'i')
oops = -(1<<30)
Expand All @@ -126,8 +127,8 @@ def A(s): return W(s)*H(s)
def M(s): return mean([s[0].start,s[0].stop]),mean([s[1].start,s[1].stop])

def binary_objects(binary):
labels,n = measurements.label(binary)
objects = measurements.find_objects(labels)
labels,n = common.label(binary)
objects = common.find_objects(labels)
return objects

def estimate_scale(binary):
Expand All @@ -153,7 +154,7 @@ def compute_boxmap(binary,scale,threshold=(.5,4),dtype='i'):
def compute_lines(segmentation,scale):
"""Given a line segmentation map, computes a list
of tuples consisting of 2D slices and masked images."""
lobjects = measurements.find_objects(segmentation)
lobjects = common.find_objects(segmentation)
lines = []
for i,o in enumerate(lobjects):
if o is None: continue
Expand Down Expand Up @@ -291,8 +292,8 @@ def rgbshow(r,g,b=None,gn=1,cn=0,ab=0,**kw):
imshow(clip(combo,0,1),**kw)

def select_regions(binary,f,min=0,nbest=100000):
labels,n = measurements.label(binary)
objects = measurements.find_objects(labels)
labels,n = common.label(binary)
objects = common.find_objects(labels)
scores = [f(o) for o in objects]
best = argsort(scores)
keep = zeros(len(objects)+1,'B')
Expand Down

0 comments on commit faec00d

Please sign in to comment.