
Can ImageGenerator handle sample weights for pixelwise segmentation? #6629

Closed
mptorr opened this issue May 15, 2017 · 15 comments

@mptorr
Contributor

mptorr commented May 15, 2017

I'm trying to use sample weighting with ImageGenerator.

  • images and masks are numpy arrays (634, 1, 64, 64)
  • masks have 5 classes (encoded as 0 to 4)
  • sample_weight is an array (634, 64, 64)
  • data augmentation via .flow with two generators (images and masks) sharing seed=42 and batch_size=32 (see the sketch at the end of this comment)
  • model being compiled with sample_weight_mode='temporal'
  • last 3 layers in the Unet model are:
conv2d_19 (Conv2D)               (None, 5, 64, 64)     325         dropout_18[0][0]                 
____________________________________________________________________________________________________
conv2d_20 (Conv2D)               (None, 1, 64, 64)     6           conv2d_19[0][0]                  
____________________________________________________________________________________________________
activation_1 (Activation)        (None, 1, 64, 64)     0           conv2d_20[0][0] 

This throws an error:

ValueError: Found a sample_weight array with shape (634, 64, 64). In order to use timestep-wise sample weighting, you should pass a 2D sample_weight array.

If I reshape sample_weight to (634, 4096) I get:

ValueError: Found a sample_weight array with shape (634, 4096) for an input with shape (32, 1, 64, 64). sample_weight cannot be broadcast.

Is this my misunderstanding of how to use ImageGenerator or is it unable to handle this particular situation?

If I do not use sample weights, the model runs and no errors are thrown.

Thanks in advance for any advice you can give.

Keras 2.0.4, Theano 0.9
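
For reference, the two-generator pattern from the third bullet looks roughly like the sketch below; apart from seed=42 and batch_size=32, the names and augmentation settings are placeholders, not code from this issue:

from keras.preprocessing.image import ImageDataGenerator

data_gen_args = dict(horizontal_flip=True)  # augmentation settings are placeholders
image_datagen = ImageDataGenerator(**data_gen_args)
mask_datagen = ImageDataGenerator(**data_gen_args)

seed = 42
image_generator = image_datagen.flow(images, seed=seed, batch_size=32)  # images: (634, 1, 64, 64)
mask_generator = mask_datagen.flow(masks, seed=seed, batch_size=32)     # masks:  (634, 1, 64, 64)
train_generator = zip(image_generator, mask_generator)  # yields (image_batch, mask_batch); use itertools.izip on Python 2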

@kglspl

kglspl commented May 17, 2017

@mptorr I am in no way an expert here, but IMHO an ImageGenerator is just that - a generator. You should be able to (and should!) test it separately to make sure it returns data in the shape you expect. Other than that, there is no magic to image augmentation; it is also my understanding that it runs on the CPU and not via Theano/TF.

So I think you should test the ImageGenerator separately, and once it works on its own it should work when put together. Hope this helps, but I am by no means an expert - others might give better advice.
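
A quick shape check along those lines could look like this (generator names taken from the sketch in the first comment):

x_batch = next(image_generator)
y_batch = next(mask_generator)
print(x_batch.shape, y_batch.shape)  # with batch_size=32, expect (32, 1, 64, 64) for each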

@ahundt
Contributor

ahundt commented May 17, 2017

#6538 references several keras repositories with implementations that do this, and the issue itself is for extending ImageDataGenerator to support segmentation, which it does not do right now.

@maderafunk

Just an idea: do you compile the model with sample_weight_mode="temporal"?

model.compile(sample_weight_mode="temporal")

@kglspl

kglspl commented May 19, 2017

I think he does:

  • model being compiled with sample_weight_mode='temporal'

@mptorr I think I have solved the problem of weighted segmentation with a (admittedly somewhat ugly) hack.

Instead of finding a way to use Keras's built-in losses (which I had trouble finding good documentation for), I built a custom loss function. The idea I had was that y_true is really just a helper for the loss function to determine the loss value. Usually it is the target image (or mask), but it doesn't need to be - it can also encode information about weights.

Once I had this idea it was pretty easy to make an image generator which "encodes" the weights into the mask, and then a custom loss function which "decodes" the weights + mask and calculates weighted categorical crossentropy. It seems to work, but I do need to test it some more. YMMV.

@ahundt
Contributor

ahundt commented May 20, 2017

@kglspl sounds like it would work reasonably well, though you're right that it is slightly hacky. Any reference code or a link?

@kglspl

kglspl commented May 20, 2017

@ahundt I haven't put it anywhere yet, but here's my loss function:

from keras.losses import categorical_crossentropy
import keras.backend as K

def weighted_categorical_crossentropy_fcn_loss(y_true, y_pred):
    # y_true is a matrix of weight-hot vectors (like 1-hot, but they hold weights instead of 1s)
    y_true_mask = K.clip(y_true, 0.0, 1.0)  # [0 0 W 0] -> [0 0 1 0] where W >= 1.
    cce = categorical_crossentropy(y_true_mask, y_pred)  # (y_true, y_pred) argument order; one dim less (each 1-hot vector -> float)
    y_true_weights_maxed = K.max(y_true, axis=-1)  # [0 120 0 0] -> 120 - get the weight for each weight-hot vector
    wcce = cce * y_true_weights_maxed
    return K.sum(wcce)

I tried cleaning it up a bit, hope it makes sense. Basically, instead of encoding y_true as a 1-hot vector (for example [0 1 0 0]) you multiply it by the weight (for example [0 30 0 0] would mean the 2nd class is correct and the weight is 30). For lack of a better term I have named this a "weight-hot" vector.

Note that I am still a newbie in ML, and while this makes sense to me I might be missing something. Also, this is work in progress, as it was written yesterday. And lastly, I haven't actually used this particular loss function in a NN yet (I did however run a similar one, but wasn't satisfied with its input form). In other words, use at your own risk. ;)
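
For illustration, here is a minimal sketch of how such weight-hot targets could be built and how the loss might be plugged in; the helper name, the channels-last layout and the compile arguments are assumptions, not tested code:

import numpy as np

def to_weight_hot(mask, weight_map, n_classes):
    # mask: integer class indices, shape (H, W); weight_map: per-pixel weights >= 1, shape (H, W)
    one_hot = np.eye(n_classes)[mask]        # (H, W, n_classes)
    return one_hot * weight_map[..., None]   # put the weight where the 1 was

# model.compile(optimizer='adam', loss=weighted_categorical_crossentropy_fcn_loss)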

@ahundt
Contributor

ahundt commented May 20, 2017

@kglspl Nice start! Is clipping the right option, or would normalizing be better? There would probably also need to be an additional helper function that makes creating the weight-hot vectors easy.

@kglspl

kglspl commented May 21, 2017

@ahundt I am not sure if / how normalizing could be used in this case, can you provide an example of what you mean?

The encoding function I use is data-specific, but I guess one could make a generic one, yes.

The good news is that this approach really does seem to work. I successfully trained a NN yesterday and am now tweaking the weights to get better image segmentation. So this process can indeed be used as a (hackish) way to apply custom weights in image segmentation with Keras.

@ahundt
Contributor

ahundt commented May 21, 2017

@kglspl clipping values to [0, 1] means [-4, -2, -1, 0.5, 1, 2, 4] becomes [0, 0, 0, 0.5, 1, 1, 1]. If you do some variation like dividing by the max of the absolute value of the numbers, you'd get [-1, -0.5, -0.25, 0.125, 0.25, 0.5, 1] which would preserve more of the weighted info you want.
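
As a quick numeric illustration of the two options (plain NumPy, values from the example above):

import numpy as np

w = np.array([-4., -2., -1., 0.5, 1., 2., 4.])
print(np.clip(w, 0.0, 1.0))     # [0. 0. 0. 0.5 1. 1. 1.]
print(w / np.max(np.abs(w)))    # [-1. -0.5 -0.25 0.125 0.25 0.5 1.]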

@maderafunk

This throws an error:
ValueError: Found a sample_weight array with shape (634, 64, 64). In order to use timestep-wise sample weighting, you should pass a 2D sample_weight array.
If I reshape sample_weight to (634, 4096) I get:

ValueError: Found a sample_weight array with shape (634, 4096) for an input with shape (32, 1, 64, 64). sample_weight cannot be broadcast.

For me, sample weighting is working for pixelwise segmentation with image generators. Are you sure there isn't some other mistake? It looks to me like the shape should be (32, 64, 64) and not (634, 64, 64), since you are using a batch size of 32.

@kglspl

kglspl commented May 21, 2017

@ahundt That is by design - I expect weight-hot vectors, which means every element of the vector must be 0 except for one, which is a positive number greater than or equal to 1. This allows me to encode both the 1-hot vector and a weight into a pixel.
But there are many ways to do it... and I think we are quite off-topic now here. :)

@mptorr
Contributor Author

mptorr commented May 21, 2017

@maderafunk This is a circular error caused by the checks Keras ImageGenerator performs on the incoming tensors. The weights need to be 2D, but then they don't match the 4D masks; and if you reshape the weights to 4D, Keras throws an error that the weights need to be 2D (when using temporal)! So it's inescapable, and it reflects the fact that ImageGenerator is not designed for pixelwise segmentation with weighting. Unfortunate, because pixelwise weighting is essential for all of my segmentation tasks.

In the thread on SegDataGenerator (#6538), @ahundt, @Dref360 and @allanzelener discuss how to handle zoom changes to the generated images. This also applies to the weights matrix: if a region with weight 12 and label 1 is magnified, you can't smoothly interpolate the weights and labels at the boundaries with other weights/labels. Perhaps @kglspl's solution may simplify this issue by joining weights and labels into one entity.

@kglspl I'll be experimenting with your suggestion, will get back to you with results. I presume it's OK to discuss here as it relates to Keras/SegDataGenerator.
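
For completeness, one workaround sometimes used to satisfy the 2D requirement (not taken from this thread) is to flatten the spatial dimensions so the model output becomes (batch, pixels, classes) and sample_weight becomes (batch, pixels). A rough, untested sketch, with layer names and shapes assumed:

from keras.layers import Activation, Permute, Reshape

# x: output of the last conv layer, shape (None, n_classes, 64, 64) with channels_first
x = Reshape((n_classes, 64 * 64))(x)  # (None, n_classes, 4096)
x = Permute((2, 1))(x)                # (None, 4096, n_classes)
outputs = Activation('softmax')(x)

model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              sample_weight_mode='temporal')
# targets would then need shape (n_samples, 4096, n_classes) and sample_weight shape (n_samples, 4096)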

@maderafunk

maderafunk commented May 21, 2017

@mptorr I've modified my ImageGenerator to return masks as well, and I then wrap it in another generator that calculates the sample weights from the augmented masks. I'm not sure if this is the proper way, but it is working for me.

This is some example code; it has not been tested:

import numpy as np

def calculateSampleWeights(y):
    # ....
    return sample_weights

def dataGenerator(train_data, labels, img_rows, img_cols, num_channels, batch_size, shuffle=True):
    imageGen = ImageGenerator()  # modified to return masks as well
    X_out = np.zeros((batch_size, img_rows, img_cols, num_channels))
    Y_out = np.zeros((batch_size, img_rows, img_cols))
    sample_weights = np.zeros((batch_size, img_rows, img_cols))
    if shuffle:
        pass  # optional shuffling comes here
    counter = 0
    while 1:
        for X_in, Y_in in zip(train_data, labels):
            # draw one augmented image/mask pair; flow expects a leading batch dimension
            x_aug, y_aug = next(imageGen.flow(X_in[np.newaxis], Y_in[np.newaxis], batch_size=1))
            X_out[counter], Y_out[counter] = x_aug[0], y_aug[0]
            sample_weights[counter] = calculateSampleWeights(Y_out[counter])
            counter += 1
            if counter >= batch_size:
                counter = 0
                yield (X_out, Y_out, sample_weights)

model.fit_generator(dataGenerator(...))

@kglspl

kglspl commented May 21, 2017

@mptorr Ok then - I was afraid we were hijacking your issue, but if you see it as a viable workaround I'd be happy to share more details if needed (this approach works and I have successfully used it today to train my NN, but I am still working on making it easier to use).

@ahundt Let me expand the explanation then. :)
The loss function above expects y_true to be "weight-hot" encoded. It is similar to 1-hot encoding, except that it has the weight in place of the 1 (and of course zeros everywhere else). For instance, the weight-hot vector [0 30 0 0] means that we have the 2nd category (that is, [0 1 0 0] if you prefer 1-hot encoding) and we want the loss to be multiplied by 30.
Which means that clipping is OK for this type of encoding - it gives me the 1-hot vector, i.e. it separates the class information from the weight. I am not sure how one could do that with the way you suggested, but I am not arguing against it - there are many ways to encode information.

However, today I discovered there is a small limitation to my encoding as explained above - you can't have weight 0, which is sometimes needed to tell the NN not to bother learning from some pixels (near the image border and similar). To work around this I made a small adjustment - I now encode weight+1 instead of weight. So if, for example, the loss function gets [0 31 0 0], it means the 1-hot vector is [0 1 0 0] and the weight is 30. Hope it makes sense.
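
As a rough sketch of that adjusted encoding (the helper name and channels-last layout are assumed):

import numpy as np

def encode_weight_hot(one_hot_mask, weight_map):
    # one_hot_mask: (..., n_classes); weight_map: (...,) with weights >= 0
    # store weight + 1 in place of the 1 so that weight 0 remains representable
    return one_hot_mask * (weight_map[..., None] + 1.0)

# and in the loss function, decode with:
#   y_true_mask = K.clip(y_true, 0.0, 1.0)
#   weights     = K.max(y_true, axis=-1) - 1.0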

Note that you can't encode different weights based on the type of misclassification this way - but in my case this is not needed (I only have 2 classes anyway). If one needs that, some other encoding would be required, but it shouldn't be difficult to invent an appropriate one.

@stale stale bot added the stale label Aug 20, 2017
@stale

stale bot commented Aug 20, 2017

This issue has been automatically marked as stale because it has not had recent activity. It will be closed after 30 days if no further activity occurs, but feel free to re-open a closed issue if needed.

@stale stale bot closed this as completed Sep 19, 2017