tf.cast equivalent in PyTorch?


Check out the PyTorch documentation.

As they mentioned:

import torch

x = torch.arange(4)  # tensor([0, 1, 2, 3]), dtype torch.int64
print(x.dtype)  # Prints "torch.int64", currently 64-bit integer type
x = x.type(torch.FloatTensor)
print(x.dtype)  # Prints "torch.float32", now 32-bit float
print(x.float().dtype)  # Still "torch.float32"
print(x.type(torch.DoubleTensor))  # Prints "tensor([0., 1., 2., 3.], dtype=torch.float64)"
print(x.type(torch.LongTensor))  # Cast back to int64, prints "tensor([0, 1, 2, 3])"
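The FloatTensor/DoubleTensor style above still works, but in current PyTorch the closest counterpart to tf.cast(x, dtype) is usually Tensor.to(dtype), which takes a torch.dtype such as torch.float32 instead of a tensor type. A minimal sketch, assuming the same small integer tensor:

```python
import torch

x = torch.arange(4)        # int64: tensor([0, 1, 2, 3])

# Tensor.to(dtype) takes a torch.dtype directly; it returns a converted
# copy when the dtype differs, and x itself is left unchanged.
y = x.to(torch.float32)    # roughly like tf.cast(x, tf.float32)
z = y.to(torch.int64)      # cast back to 64-bit integers

print(x.dtype, y.dtype, z.dtype)  # torch.int64 torch.float32 torch.int64
```

Shorthand methods such as x.float(), x.double(), x.long() and x.bool() do the same for common dtypes.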
Author: SarwarKhan


Updated on June 11, 2022

Comments

  • SarwarKhan

    I am new to PyTorch. TensorFlow has the APIs tf.cast() and tf.shape(). tf.cast has a specific purpose in TensorFlow; is there anything equivalent in torch? I have a tensor x with shape (128, 64, 32, 32). tf.shape(x) creates a 1-D tensor holding the dimensions, while x.shape gives the actual dimensions. I need the behavior of tf.shape(x) in torch.

    tf.cast plays a different role than simply changing a tensor's dtype does in torch.

    Does anyone know of an equivalent API in torch/PyTorch?
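Both parts of the question can be sketched with core PyTorch calls. tf.shape(x) returns the dimensions as a 1-D integer tensor, while PyTorch's x.shape (the same as x.size()) is a torch.Size, a subclass of Python's tuple; wrapping it in torch.tensor gives a rough tensor analogue. For the value conversion itself, Tensor.to(dtype) truncates floats toward zero when casting to an integer type, which, as far as I know, matches the [1.8, 2.2] -> [1, 2] example in the tf.cast documentation. A minimal sketch, assuming the (128, 64, 32, 32) shape from the question:

```python
import torch

x = torch.zeros(128, 64, 32, 32)

# x.shape is a torch.Size (a tuple subclass) holding plain Python ints.
print(x.shape)  # torch.Size([128, 64, 32, 32])

# Rough analogue of tf.shape(x): a 1-D tensor with one entry per dimension.
# tf.shape defaults to int32, so that dtype is requested explicitly here.
shape_tensor = torch.tensor(x.shape, dtype=torch.int32)
print(shape_tensor.shape)  # torch.Size([4])

# Value conversion: float -> int truncates toward zero.
f = torch.tensor([1.8, 2.2, -1.7])
print(f.to(torch.int32).tolist())  # [1, 2, -1]

# Bool -> float conversion also goes through .to(): True -> 1.0, False -> 0.0.
mask = (f > 0).to(torch.float32)
print(mask.tolist())  # [1.0, 1.0, 0.0]
```

If only the rank is needed, x.dim() (or len(x.shape)) returns it as a plain int.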