I am trying to learn TensorFlow with Java. I exported my model from Python in ".pb" format and am trying to make predictions with it. The model works well in Python, but when I run predictions in Java, I don't get any valid output.
Is there something clearly wrong here, or does the problem lie in the way the model was exported?
```java
public List<String> predict(ChessBoard chessBoard) {
    // Create a tensor
    TFloat32 inputTensor = chessBoard.encodeBoardToTensor();
    // Feed the tensor to the model
    Result res = this.model.call(Collections.singletonMap("input_layer", inputTensor));
    Tensor output = res.get(0);
    // Copy the model output into a Java array
    float[] predictions = new float[4672];
    output.asRawTensor().data().asFloats().read(predictions);
    ...
}
```
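To tell whether the array read back from `output` holds real predictions (rather than all zeros, NaN or infinity), it can help to check it before decoding it into moves. Below is a hypothetical, stdlib-only sketch; `PredictionCheck` and its methods are illustrative names, not part of TensorFlow Java.

```java
/** Hypothetical diagnostic helper; not part of the TensorFlow Java API. */
public final class PredictionCheck {
    private PredictionCheck() {}

    /** Returns true if every value is finite (no NaN or infinity). */
    public static boolean allFinite(float[] values) {
        for (float v : values) {
            if (!Float.isFinite(v)) {
                return false;
            }
        }
        return true;
    }

    /** Sum of all values; close to 1.0 if the model ends with a softmax. */
    public static double sum(float[] values) {
        double total = 0.0;
        for (float v : values) {
            total += v;
        }
        return total;
    }

    /** Index of the largest value, or -1 for an empty array. */
    public static int argmax(float[] values) {
        int best = -1;
        for (int i = 0; i < values.length; i++) {
            if (best < 0 || values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        float[] predictions = new float[4672]; // same size as in the question
        predictions[42] = 0.9f; // made-up values for illustration only
        predictions[7] = 0.1f;
        System.out.println("finite=" + allFinite(predictions)
                + " sum=" + sum(predictions)
                + " argmax=" + argmax(predictions));
    }
}
```

Running the same checks on the Python output for the same board makes it easier to see whether the Java values differ in scale (for example, logits vs. probabilities) or are simply empty.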
I checked the format of `inputTensor`, and it matches the one used in Python.
Its shape is (1, 8, 8, 12): a batch of one 8×8 chessboard with 12 channels per square.
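Matching the shape is necessary but not sufficient: the flat order of the 768 values must also match. A hypothetical sketch of a one-hot, channels-last (1, 8, 8, 12) encoding in row-major order follows; `BoardEncoding`, the `squares` input format, and the mapping of pieces to the 12 planes are all assumptions and must be made to agree with the Python encoder.

```java
/**
 * Hypothetical sketch of a (1, 8, 8, 12) channels-last board encoding.
 * The plane order (which piece maps to which of the 12 channels) is an
 * assumption and must match the Python encoder exactly.
 */
public final class BoardEncoding {
    public static final int BATCH = 1, RANKS = 8, FILES = 8, PLANES = 12;

    private BoardEncoding() {}

    /** Row-major (C-order) offset of element [b][rank][file][plane]. */
    public static int flatIndex(int b, int rank, int file, int plane) {
        return ((b * RANKS + rank) * FILES + file) * PLANES + plane;
    }

    /**
     * One-hot encodes a board given as 64 plane indices (0..11 for a
     * piece, -1 for an empty square), where square = rank * 8 + file.
     */
    public static float[] encode(int[] squares) {
        if (squares.length != RANKS * FILES) {
            throw new IllegalArgumentException("expected 64 squares");
        }
        float[] data = new float[BATCH * RANKS * FILES * PLANES];
        for (int sq = 0; sq < squares.length; sq++) {
            int plane = squares[sq];
            if (plane < -1 || plane >= PLANES) {
                throw new IllegalArgumentException("bad plane " + plane);
            }
            if (plane >= 0) {
                data[flatIndex(0, sq / FILES, sq % FILES, plane)] = 1.0f;
            }
        }
        return data;
    }

    public static void main(String[] args) {
        int[] squares = new int[64];
        java.util.Arrays.fill(squares, -1);
        squares[0] = 3; // hypothetical: some piece on rank 0, file 0
        float[] data = encode(squares);
        System.out.println(data.length + " " + data[flatIndex(0, 0, 0, 3)]);
    }
}
```

NumPy arrays are row-major (C-order) by default, so a Python encoder that fills a (1, 8, 8, 12) array produces the same flat layout as `flatIndex`; comparing the flattened Java and Python arrays element by element for one board is a quick way to rule out an encoding mismatch.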