diff --git a/ocrolib/common.py b/ocrolib/common.py index 276743d4..25730bdf 100644 --- a/ocrolib/common.py +++ b/ocrolib/common.py @@ -50,7 +50,7 @@ def pyargsort(seq,cmp=cmp,key=lambda x:x): return sorted(range(len(seq)),key=lambda x:key(seq.__getitem__(x)),cmp=cmp) def renumber_by_xcenter(seg): - objects = [(slice(0,0),slice(0,0))]+measurements.find_objects(seg) + objects = [(slice(0,0),slice(0,0))]+find_objects(seg) def xc(o): return mean((o[1].start,o[1].stop)) xs = array([xc(o) for o in objects]) order = argsort(xs) @@ -58,20 +58,32 @@ def xc(o): return mean((o[1].start,o[1].stop)) for i,j in enumerate(order): segmap[j] = i return segmap[seg] -def flexible_find_objects(image): - """Like measurements.find_objects, but tries to - be a bit more flexible about the datatypes it accepts.""" - # first try the default type - try: return measurements.find_objects(image) +def label(image,**kw): + """measurements.label fails to document what types it accepts, + and it fails randomly with different types on different + platforms. This tries to work around that.""" + try: return measurements.label(image,**kw) except: pass - types = ["int32","int64","int16"] + types = ["int32","uint32","int64","unit64","int16","uint16"] for t in types: - # try with type conversions - try: return measurements.find_objects(array(image,dtype=t)) + try: return measurements.label(array(image,dtype=t),**kw) except: pass # let it raise the same exception as before - return measurements.find_objects(image) + return measurements.label(image,**kw) +def find_objects(image,**kw): + """measurements.find_objects fails to document what types it accepts, + and it fails randomly with different types on different + platforms. This tries to work around that.""" + try: return measurements.find_objects(image,**kw) + except: pass + types = ["int32","uint32","int64","unit64","int16","uint16"] + for t in types: + try: return measurements.find_objects(array(image,dtype=t),**kw) + except: pass + # let it raise the same exception as before + return measurements.find_objects(image,**kw) + def rgb2int(image): """Converts a rank 3 array with RGB values stored in the last axis into a rank 2 array containing 32 bit RGB values.""" @@ -109,7 +121,7 @@ def setImageMasked(self,image,mask=None,lo=None,hi=None): labels,correspondence = renumber_labels_ordered(labels,correspondence=1) self.labels = labels self.correspondence = correspondence - self.objects = [None]+flexible_find_objects(labels) + self.objects = [None]+find_objects(labels) def setPageColumns(self,image): """Set the image to be iterated over. This should be an RGB image, ndim==3, dtype=='B'. This iterates over the columns.""" @@ -811,7 +823,7 @@ def estimate_xheight(line,scale=1.0,debug=0): return bottom-top,bottom def keep_marked(image,markers): - labels,_ = measurements.label(image) + labels,_ = label(image) imshow(sin(17.1*labels),cmap=cm.jet) marked = unique(labels*(markers!=0)) print marked @@ -840,7 +852,7 @@ def latin_filter(line,scale=1.0,r=1.2,debug=0): def remove_noise(line,minsize=8): bin = (line>0.5*amax(line)) - labels,n = measurements.label(bin) + labels,n = label(bin) sums = measurements.sum(bin,labels,range(n+1)) sums = sums[labels] good = minimum(bin,1-(sums>0)*(sums0.5,center) center = filters.maximum_filter(center,(2,2)) - center,_ = measurements.label(center) + center,_ = common.label(center) center = psegutils.spread_labels(center) center *= image return center @@ -136,7 +137,7 @@ def charseg(self,line): tracks = dplineseg2(line,imweight=self.imweight,bweight=self.bweight, diagweight=self.diagweight,debug=self.debug,r=self.r) tracks = array(tracks<0.5*amax(tracks),'i') - tracks,_ = measurements.label(tracks) + tracks,_ = common.label(tracks) self.tracks = tracks rsegs = psegutils.spread_labels(tracks) rsegs = rsegs*(line>0.5*amax(line)) diff --git a/ocrolib/psegutils.py b/ocrolib/psegutils.py index 9b8f418d..1258bad5 100644 --- a/ocrolib/psegutils.py +++ b/ocrolib/psegutils.py @@ -4,6 +4,7 @@ from scipy.ndimage import filters,interpolation,morphology,measurements from scipy import stats from scipy.misc import imsave +import common class record: def __init__(self,**kw): self.__dict__.update(kw) @@ -80,7 +81,7 @@ def spread_labels(labels,maxdist=9999999): return spread def keep_marked(image,markers): - labels,_ = measurements.label(image) + labels,_ = common.label(image) marked = unique(labels*(markers!=0)) kept = in1d(labels.ravel(),marked) return (image!=0)*kept.reshape(*labels.shape) @@ -100,7 +101,7 @@ def correspondences(labels1,labels2): def propagate_labels_simple(regions,labels): """Spread the labels to the corresponding regions.""" - rlabels,_ = measurements.label(regions) + rlabels,_ = common.label(regions) cors = correspondences(rlabels,labels) outputs = zeros(amax(rlabels)+1,'i') for o,i in cors.T: outputs[o] = i @@ -109,7 +110,7 @@ def propagate_labels_simple(regions,labels): def propagate_labels(regions,labels,conflict=0): """Spread the labels to the corresponding regions.""" - rlabels,_ = measurements.label(regions) + rlabels,_ = common.label(regions) cors = correspondences(rlabels,labels) outputs = zeros(amax(rlabels)+1,'i') oops = -(1<<30) @@ -126,8 +127,8 @@ def A(s): return W(s)*H(s) def M(s): return mean([s[0].start,s[0].stop]),mean([s[1].start,s[1].stop]) def binary_objects(binary): - labels,n = measurements.label(binary) - objects = measurements.find_objects(labels) + labels,n = common.label(binary) + objects = common.find_objects(labels) return objects def estimate_scale(binary): @@ -153,7 +154,7 @@ def compute_boxmap(binary,scale,threshold=(.5,4),dtype='i'): def compute_lines(segmentation,scale): """Given a line segmentation map, computes a list of tuples consisting of 2D slices and masked images.""" - lobjects = measurements.find_objects(segmentation) + lobjects = common.find_objects(segmentation) lines = [] for i,o in enumerate(lobjects): if o is None: continue @@ -291,8 +292,8 @@ def rgbshow(r,g,b=None,gn=1,cn=0,ab=0,**kw): imshow(clip(combo,0,1),**kw) def select_regions(binary,f,min=0,nbest=100000): - labels,n = measurements.label(binary) - objects = measurements.find_objects(labels) + labels,n = common.label(binary) + objects = common.find_objects(labels) scores = [f(o) for o in objects] best = argsort(scores) keep = zeros(len(objects)+1,'B')