
map_labels fails for n_classes > 1 #315

@hvgazula

    File "/net/vast-storage/scratch/vast/gablab/hgazula/nobrainer/nobrainer/dataset.py", line 338, in None  *
        lambda x, y: (x, tf.one_hot(y, self.n_classes))

    TypeError: Value passed to parameter 'indices' has DataType float32 not in list of allowed values: uint8, int8, int32, int64

Solution: cast y to one of the allowed integer types (e.g. tf.int32) before passing it to tf.one_hot. The rest of the pipeline should then work unchanged.
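Below is a minimal sketch of the proposed fix. It assumes a tf.data pipeline shaped like the lambda in the traceback. N_CLASSES, map_labels_one_hot and the toy tensors are hypothetical stand-ins for self.n_classes and the real nobrainer dataset, not nobrainer's actual code. The only substantive change from line 338 is the tf.cast on y.

```python
import tensorflow as tf

N_CLASSES = 3  # hypothetical stand-in for self.n_classes


def map_labels_one_hot(x, y, n_classes=N_CLASSES):
    # tf.one_hot only accepts integer indices (uint8, int8, int32, int64),
    # so float32 label volumes must be cast before one-hot encoding.
    y = tf.cast(y, tf.int32)
    return x, tf.one_hot(y, n_classes)


# Toy data (hypothetical): float32 features and float32 labels in {0, 1, 2},
# mimicking label volumes that were loaded as float32.
x = tf.zeros((4, 2), dtype=tf.float32)
y = tf.constant([0.0, 1.0, 2.0, 1.0], dtype=tf.float32)

ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(4)
ds = ds.map(map_labels_one_hot)  # map unpacks the (x, y) tuple into two args

xb, yb = next(iter(ds))  # yb has shape (4, N_CLASSES) and dtype float32

# The same cast also works for the n_classes = 2 case from the TODO
# (in this toy setting only; the real pipeline still needs checking).
_, y2 = map_labels_one_hot(x, tf.constant([0.0, 1.0, 1.0, 0.0]), n_classes=2)
```

Note that tf.cast truncates toward zero, so if label volumes can contain non-integral values (for example after resampling with interpolation), applying tf.round before the cast may be safer.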

TODO: verify that the same fix works with n_classes = 2.
