Home | History | Annotate | Download | only in tensorflow
      1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 package org.tensorflow;
     17 
     18 import static java.nio.charset.StandardCharsets.UTF_8;
     19 import static org.junit.Assert.assertArrayEquals;
     20 import static org.junit.Assert.assertEquals;
     21 import static org.junit.Assert.assertTrue;
     22 import static org.junit.Assert.fail;
     23 
     24 import java.nio.ByteBuffer;
     25 import java.nio.ByteOrder;
     26 import java.nio.DoubleBuffer;
     27 import java.nio.FloatBuffer;
     28 import java.nio.IntBuffer;
     29 import java.nio.LongBuffer;
     30 import org.junit.Test;
     31 import org.junit.runner.RunWith;
     32 import org.junit.runners.JUnit4;
     33 import org.tensorflow.types.UInt8;
     34 
     35 /** Unit tests for {@link org.tensorflow.Tensor}. */
     36 @RunWith(JUnit4.class)
     37 public class TensorTest {
     38   private static final double EPSILON = 1e-7;
     39   private static final float EPSILON_F = 1e-7f;
     40 
     41   @Test
     42   public void createWithByteBuffer() {
     43     double[] doubles = {1d, 2d, 3d, 4d};
     44     long[] doubles_shape = {4};
     45     boolean[] bools = {true, false, true, false};
     46     long[] bools_shape = {4};
     47     byte[] bools_ = TestUtil.bool2byte(bools);
     48     byte[] strings = "test".getBytes(UTF_8);
     49     long[] strings_shape = {};
     50     byte[] strings_; // raw TF_STRING
     51     try (Tensor<String> t = Tensors.create(strings)) {
     52       ByteBuffer to = ByteBuffer.allocate(t.numBytes());
     53       t.writeTo(to);
     54       strings_ = to.array();
     55     }
     56 
     57     // validate creating a tensor using a byte buffer
     58     {
     59       try (Tensor<Boolean> t = Tensor.create(Boolean.class, bools_shape, ByteBuffer.wrap(bools_))) {
     60         boolean[] actual = t.copyTo(new boolean[bools_.length]);
     61         for (int i = 0; i < bools.length; ++i) {
     62           assertEquals("" + i, bools[i], actual[i]);
     63         }
     64       }
     65 
     66       // note: the buffer is expected to contain raw TF_STRING (as per C API)
     67       try (Tensor<String> t =
     68           Tensor.create(String.class, strings_shape, ByteBuffer.wrap(strings_))) {
     69         assertArrayEquals(strings, t.bytesValue());
     70       }
     71     }
     72 
     73     // validate creating a tensor using a direct byte buffer (in host order)
     74     {
     75       ByteBuffer buf = ByteBuffer.allocateDirect(8 * doubles.length).order(ByteOrder.nativeOrder());
     76       buf.asDoubleBuffer().put(doubles);
     77       try (Tensor<Double> t = Tensor.create(Double.class, doubles_shape, buf)) {
     78         double[] actual = new double[doubles.length];
     79         assertArrayEquals(doubles, t.copyTo(actual), EPSILON);
     80       }
     81     }
     82 
     83     // validate shape checking
     84     try (Tensor<Boolean> t =
     85         Tensor.create(Boolean.class, new long[bools_.length * 2], ByteBuffer.wrap(bools_))) {
     86       fail("should have failed on incompatible buffer");
     87     } catch (IllegalArgumentException e) {
     88       // expected
     89     }
     90   }
     91 
     92   @Test
     93   public void createFromBufferWithNonNativeByteOrder() {
     94     double[] doubles = {1d, 2d, 3d, 4d};
     95     DoubleBuffer buf =
     96         ByteBuffer.allocate(8 * doubles.length)
     97             .order(
     98                 ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN
     99                     ? ByteOrder.BIG_ENDIAN
    100                     : ByteOrder.LITTLE_ENDIAN)
    101             .asDoubleBuffer()
    102             .put(doubles);
    103     buf.flip();
    104     try (Tensor<Double> t = Tensor.create(new long[] {doubles.length}, buf)) {
    105       double[] actual = new double[doubles.length];
    106       assertArrayEquals(doubles, t.copyTo(actual), EPSILON);
    107     }
    108   }
    109 
    110   @Test
    111   public void createWithTypedBuffer() {
    112     int[] ints = {1, 2, 3, 4};
    113     float[] floats = {1f, 2f, 3f, 4f};
    114     double[] doubles = {1d, 2d, 3d, 4d};
    115     long[] longs = {1L, 2L, 3L, 4L};
    116     long[] shape = {4};
    117 
    118     // validate creating a tensor using a typed buffer
    119     {
    120       try (Tensor<Double> t = Tensor.create(shape, DoubleBuffer.wrap(doubles))) {
    121         double[] actual = new double[doubles.length];
    122         assertArrayEquals(doubles, t.copyTo(actual), EPSILON);
    123       }
    124       try (Tensor<Float> t = Tensor.create(shape, FloatBuffer.wrap(floats))) {
    125         float[] actual = new float[floats.length];
    126         assertArrayEquals(floats, t.copyTo(actual), EPSILON_F);
    127       }
    128       try (Tensor<Integer> t = Tensor.create(shape, IntBuffer.wrap(ints))) {
    129         int[] actual = new int[ints.length];
    130         assertArrayEquals(ints, t.copyTo(actual));
    131       }
    132       try (Tensor<Long> t = Tensor.create(shape, LongBuffer.wrap(longs))) {
    133         long[] actual = new long[longs.length];
    134         assertArrayEquals(longs, t.copyTo(actual));
    135       }
    136     }
    137 
    138     // validate shape-checking
    139     {
    140       try (Tensor<Double> t =
    141           Tensor.create(new long[doubles.length + 1], DoubleBuffer.wrap(doubles))) {
    142         fail("should have failed on incompatible buffer");
    143       } catch (IllegalArgumentException e) {
    144         // expected
    145       }
    146       try (Tensor<Float> t = Tensor.create(new long[floats.length + 1], FloatBuffer.wrap(floats))) {
    147         fail("should have failed on incompatible buffer");
    148       } catch (IllegalArgumentException e) {
    149         // expected
    150       }
    151       try (Tensor<Integer> t = Tensor.create(new long[ints.length + 1], IntBuffer.wrap(ints))) {
    152         fail("should have failed on incompatible buffer");
    153       } catch (IllegalArgumentException e) {
    154         // expected
    155       }
    156       try (Tensor<Long> t = Tensor.create(new long[longs.length + 1], LongBuffer.wrap(longs))) {
    157         fail("should have failed on incompatible buffer");
    158       } catch (IllegalArgumentException e) {
    159         // expected
    160       }
    161     }
    162   }
    163 
    164   @Test
    165   public void writeTo() {
    166     int[] ints = {1, 2, 3};
    167     float[] floats = {1f, 2f, 3f};
    168     double[] doubles = {1d, 2d, 3d};
    169     long[] longs = {1L, 2L, 3L};
    170     boolean[] bools = {true, false, true};
    171 
    172     try (Tensor<Integer> tints = Tensors.create(ints);
    173         Tensor<Float> tfloats = Tensors.create(floats);
    174         Tensor<Double> tdoubles = Tensors.create(doubles);
    175         Tensor<Long> tlongs = Tensors.create(longs);
    176         Tensor<Boolean> tbools = Tensors.create(bools)) {
    177 
    178       // validate that any datatype is readable with ByteBuffer (content, position)
    179       {
    180         ByteBuffer bbuf = ByteBuffer.allocate(1024).order(ByteOrder.nativeOrder());
    181 
    182         bbuf.clear(); // FLOAT
    183         tfloats.writeTo(bbuf);
    184         assertEquals(tfloats.numBytes(), bbuf.position());
    185         bbuf.flip();
    186         assertEquals(floats[0], bbuf.asFloatBuffer().get(0), EPSILON);
    187         bbuf.clear(); // DOUBLE
    188         tdoubles.writeTo(bbuf);
    189         assertEquals(tdoubles.numBytes(), bbuf.position());
    190         bbuf.flip();
    191         assertEquals(doubles[0], bbuf.asDoubleBuffer().get(0), EPSILON);
    192         bbuf.clear(); // INT32
    193         tints.writeTo(bbuf);
    194         assertEquals(tints.numBytes(), bbuf.position());
    195         bbuf.flip();
    196         assertEquals(ints[0], bbuf.asIntBuffer().get(0));
    197         bbuf.clear(); // INT64
    198         tlongs.writeTo(bbuf);
    199         assertEquals(tlongs.numBytes(), bbuf.position());
    200         bbuf.flip();
    201         assertEquals(longs[0], bbuf.asLongBuffer().get(0));
    202         bbuf.clear(); // BOOL
    203         tbools.writeTo(bbuf);
    204         assertEquals(tbools.numBytes(), bbuf.position());
    205         bbuf.flip();
    206         assertEquals(bools[0], bbuf.get(0) != 0);
    207       }
    208 
    209       // validate the use of direct buffers
    210       {
    211         DoubleBuffer buf =
    212             ByteBuffer.allocateDirect(tdoubles.numBytes())
    213                 .order(ByteOrder.nativeOrder())
    214                 .asDoubleBuffer();
    215         tdoubles.writeTo(buf);
    216         assertTrue(buf.isDirect());
    217         assertEquals(tdoubles.numElements(), buf.position());
    218         assertEquals(doubles[0], buf.get(0), EPSILON);
    219       }
    220 
    221       // validate typed buffers (content, position)
    222       {
    223         FloatBuffer buf = FloatBuffer.allocate(tfloats.numElements());
    224         tfloats.writeTo(buf);
    225         assertEquals(tfloats.numElements(), buf.position());
    226         assertEquals(floats[0], buf.get(0), EPSILON);
    227       }
    228       {
    229         DoubleBuffer buf = DoubleBuffer.allocate(tdoubles.numElements());
    230         tdoubles.writeTo(buf);
    231         assertEquals(tdoubles.numElements(), buf.position());
    232         assertEquals(doubles[0], buf.get(0), EPSILON);
    233       }
    234       {
    235         IntBuffer buf = IntBuffer.allocate(tints.numElements());
    236         tints.writeTo(buf);
    237         assertEquals(tints.numElements(), buf.position());
    238         assertEquals(ints[0], buf.get(0));
    239       }
    240       {
    241         LongBuffer buf = LongBuffer.allocate(tlongs.numElements());
    242         tlongs.writeTo(buf);
    243         assertEquals(tlongs.numElements(), buf.position());
    244         assertEquals(longs[0], buf.get(0));
    245       }
    246 
    247       // validate byte order conversion
    248       {
    249         DoubleBuffer foreignBuf =
    250             ByteBuffer.allocate(tdoubles.numBytes())
    251                 .order(
    252                     ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN
    253                         ? ByteOrder.BIG_ENDIAN
    254                         : ByteOrder.LITTLE_ENDIAN)
    255                 .asDoubleBuffer();
    256         tdoubles.writeTo(foreignBuf);
    257         foreignBuf.flip();
    258         double[] actual = new double[foreignBuf.remaining()];
    259         foreignBuf.get(actual);
    260         assertArrayEquals(doubles, actual, EPSILON);
    261       }
    262 
    263       // validate that incompatible buffers are rejected
    264       {
    265         IntBuffer badbuf1 = IntBuffer.allocate(128);
    266         try {
    267           tbools.writeTo(badbuf1);
    268           fail("should have failed on incompatible buffer");
    269         } catch (IllegalArgumentException e) {
    270           // expected
    271         }
    272         FloatBuffer badbuf2 = FloatBuffer.allocate(128);
    273         try {
    274           tbools.writeTo(badbuf2);
    275           fail("should have failed on incompatible buffer");
    276         } catch (IllegalArgumentException e) {
    277           // expected
    278         }
    279         DoubleBuffer badbuf3 = DoubleBuffer.allocate(128);
    280         try {
    281           tbools.writeTo(badbuf3);
    282           fail("should have failed on incompatible buffer");
    283         } catch (IllegalArgumentException e) {
    284           // expected
    285         }
    286         LongBuffer badbuf4 = LongBuffer.allocate(128);
    287         try {
    288           tbools.writeTo(badbuf4);
    289           fail("should have failed on incompatible buffer");
    290         } catch (IllegalArgumentException e) {
    291           // expected
    292         }
    293       }
    294     }
    295   }
    296 
    297   @Test
    298   public void scalars() {
    299     try (Tensor<Float> t = Tensors.create(2.718f)) {
    300       assertEquals(DataType.FLOAT, t.dataType());
    301       assertEquals(0, t.numDimensions());
    302       assertEquals(0, t.shape().length);
    303       assertEquals(2.718f, t.floatValue(), EPSILON_F);
    304     }
    305 
    306     try (Tensor<Double> t = Tensors.create(3.1415)) {
    307       assertEquals(DataType.DOUBLE, t.dataType());
    308       assertEquals(0, t.numDimensions());
    309       assertEquals(0, t.shape().length);
    310       assertEquals(3.1415, t.doubleValue(), EPSILON);
    311     }
    312 
    313     try (Tensor<Integer> t = Tensors.create(-33)) {
    314       assertEquals(DataType.INT32, t.dataType());
    315       assertEquals(0, t.numDimensions());
    316       assertEquals(0, t.shape().length);
    317       assertEquals(-33, t.intValue());
    318     }
    319 
    320     try (Tensor<Long> t = Tensors.create(8589934592L)) {
    321       assertEquals(DataType.INT64, t.dataType());
    322       assertEquals(0, t.numDimensions());
    323       assertEquals(0, t.shape().length);
    324       assertEquals(8589934592L, t.longValue());
    325     }
    326 
    327     try (Tensor<Boolean> t = Tensors.create(true)) {
    328       assertEquals(DataType.BOOL, t.dataType());
    329       assertEquals(0, t.numDimensions());
    330       assertEquals(0, t.shape().length);
    331       assertTrue(t.booleanValue());
    332     }
    333 
    334     final byte[] bytes = {1, 2, 3, 4};
    335     try (Tensor<String> t = Tensors.create(bytes)) {
    336       assertEquals(DataType.STRING, t.dataType());
    337       assertEquals(0, t.numDimensions());
    338       assertEquals(0, t.shape().length);
    339       assertArrayEquals(bytes, t.bytesValue());
    340     }
    341   }
    342 
    343   @Test
    344   public void nDimensional() {
    345     double[] vector = {1.414, 2.718, 3.1415};
    346     try (Tensor<Double> t = Tensors.create(vector)) {
    347       assertEquals(DataType.DOUBLE, t.dataType());
    348       assertEquals(1, t.numDimensions());
    349       assertArrayEquals(new long[] {3}, t.shape());
    350 
    351       double[] got = new double[3];
    352       assertArrayEquals(vector, t.copyTo(got), EPSILON);
    353     }
    354 
    355     int[][] matrix = {{1, 2, 3}, {4, 5, 6}};
    356     try (Tensor<Integer> t = Tensors.create(matrix)) {
    357       assertEquals(DataType.INT32, t.dataType());
    358       assertEquals(2, t.numDimensions());
    359       assertArrayEquals(new long[] {2, 3}, t.shape());
    360 
    361       int[][] got = new int[2][3];
    362       assertArrayEquals(matrix, t.copyTo(got));
    363     }
    364 
    365     long[][][] threeD = {
    366       {{1}, {3}, {5}, {7}, {9}}, {{2}, {4}, {6}, {8}, {0}},
    367     };
    368     try (Tensor<Long> t = Tensors.create(threeD)) {
    369       assertEquals(DataType.INT64, t.dataType());
    370       assertEquals(3, t.numDimensions());
    371       assertArrayEquals(new long[] {2, 5, 1}, t.shape());
    372 
    373       long[][][] got = new long[2][5][1];
    374       assertArrayEquals(threeD, t.copyTo(got));
    375     }
    376 
    377     boolean[][][][] fourD = {
    378       {{{false, false, false, true}, {false, false, true, false}}},
    379       {{{false, false, true, true}, {false, true, false, false}}},
    380       {{{false, true, false, true}, {false, true, true, false}}},
    381     };
    382     try (Tensor<Boolean> t = Tensors.create(fourD)) {
    383       assertEquals(DataType.BOOL, t.dataType());
    384       assertEquals(4, t.numDimensions());
    385       assertArrayEquals(new long[] {3, 1, 2, 4}, t.shape());
    386 
    387       boolean[][][][] got = new boolean[3][1][2][4];
    388       assertArrayEquals(fourD, t.copyTo(got));
    389     }
    390   }
    391 
    392   @Test
    393   public void testNDimensionalStringTensor() {
    394     byte[][][] matrix = new byte[4][3][];
    395     for (int i = 0; i < 4; ++i) {
    396       for (int j = 0; j < 3; ++j) {
    397         matrix[i][j] = String.format("(%d, %d) = %d", i, j, i << j).getBytes(UTF_8);
    398       }
    399     }
    400     try (Tensor<String> t = Tensors.create(matrix)) {
    401       assertEquals(DataType.STRING, t.dataType());
    402       assertEquals(2, t.numDimensions());
    403       assertArrayEquals(new long[] {4, 3}, t.shape());
    404 
    405       byte[][][] got = t.copyTo(new byte[4][3][]);
    406       assertEquals(4, got.length);
    407       for (int i = 0; i < 4; ++i) {
    408         assertEquals(String.format("%d", i), 3, got[i].length);
    409         for (int j = 0; j < 3; ++j) {
    410           assertArrayEquals(String.format("(%d, %d)", i, j), matrix[i][j], got[i][j]);
    411         }
    412       }
    413     }
    414   }
    415 
    416   @Test
    417   public void testUInt8Tensor() {
    418     byte[] vector = new byte[] {1, 2, 3, 4};
    419     try (Tensor<UInt8> t = Tensor.create(vector, UInt8.class)) {
    420       assertEquals(DataType.UINT8, t.dataType());
    421       assertEquals(1, t.numDimensions());
    422       assertArrayEquals(new long[] {4}, t.shape());
    423 
    424       byte[] got = t.copyTo(new byte[4]);
    425       assertArrayEquals(vector, got);
    426     }
    427   }
    428 
    429   @Test
    430   public void testCreateFromArrayOfBoxed() {
    431     Integer[] vector = new Integer[] {1, 2, 3, 4};
    432     try (Tensor<Integer> t = Tensor.create(vector, Integer.class)) {
    433       fail("Tensor.create() should fail because it was given an array of boxed values");
    434     } catch (IllegalArgumentException e) {
    435         // The expected exception
    436     }
    437   }
    438 
    439   @Test
    440   public void failCreateOnMismatchedDimensions() {
    441     int[][][] invalid = new int[3][1][];
    442     for (int x = 0; x < invalid.length; ++x) {
    443       for (int y = 0; y < invalid[x].length; ++y) {
    444         invalid[x][y] = new int[x + y + 1];
    445       }
    446     }
    447     try (Tensor<?> t = Tensor.create(invalid)) {
    448       fail("Tensor.create() should fail because of differing sizes in the 3rd dimension");
    449     } catch (IllegalArgumentException e) {
    450       // The expected exception.
    451     }
    452   }
    453 
    454   @Test
    455   public void failCopyToOnIncompatibleDestination() {
    456     try (final Tensor<Integer> matrix = Tensors.create(new int[][] {{1, 2}, {3, 4}})) {
    457       try {
    458         matrix.copyTo(new int[2]);
    459         fail("should have failed on dimension mismatch");
    460       } catch (IllegalArgumentException e) {
    461         // The expected exception.
    462       }
    463 
    464       try {
    465         matrix.copyTo(new float[2][2]);
    466         fail("should have failed on DataType mismatch");
    467       } catch (IllegalArgumentException e) {
    468         // The expected exception.
    469       }
    470 
    471       try {
    472         matrix.copyTo(new int[2][3]);
    473         fail("should have failed on shape mismatch");
    474       } catch (IllegalArgumentException e) {
    475         // The expected exception.
    476       }
    477     }
    478   }
    479 
    480   @Test
    481   public void failCopyToOnScalar() {
    482     try (final Tensor<Integer> scalar = Tensors.create(3)) {
    483       try {
    484         scalar.copyTo(3);
    485         fail("copyTo should fail on scalar tensors, suggesting use of primitive accessors instead");
    486       } catch (IllegalArgumentException e) {
    487         // The expected exception.
    488       }
    489     }
    490   }
    491 
    492   @Test
    493   public void failOnArbitraryObject() {
    494     try (Tensor<?> t = Tensor.create(new Object())) {
    495       fail("should fail on creating a Tensor with a Java object that has no equivalent DataType");
    496     } catch (IllegalArgumentException e) {
    497       // The expected exception.
    498     }
    499   }
    500 
    501   @Test
    502   public void failOnZeroDimension() {
    503     try (Tensor<Integer> t = Tensors.create(new int[3][0][1])) {
    504       fail("should fail on creating a Tensor where one of the dimensions is 0");
    505     } catch (IllegalArgumentException e) {
    506       // The expected exception.
    507     }
    508   }
    509 
    510   @Test
    511   public void useAfterClose() {
    512     int n = 4;
    513     Tensor<?> t = Tensor.create(n);
    514     t.close();
    515     try {
    516       t.intValue();
    517     } catch (NullPointerException e) {
    518       // The expected exception.
    519     }
    520   }
    521 
    522   @Test
    523   public void fromHandle() {
    524     // fromHandle is a package-visible method intended for use when the C TF_Tensor object has been
    525     // created independently of the Java code. In practice, two Tensor instances MUST NOT have the
    526     // same native handle.
    527     //
    528     // An exception is made for this test, where the pitfalls of this is avoided by not calling
    529     // close() on both Tensors.
    530     final float[][] matrix = {{1, 2, 3}, {4, 5, 6}};
    531     try (Tensor<Float> src = Tensors.create(matrix)) {
    532       Tensor<Float> cpy = Tensor.fromHandle(src.getNativeHandle()).expect(Float.class);
    533       assertEquals(src.dataType(), cpy.dataType());
    534       assertEquals(src.numDimensions(), cpy.numDimensions());
    535       assertArrayEquals(src.shape(), cpy.shape());
    536       assertArrayEquals(matrix, cpy.copyTo(new float[2][3]));
    537     }
    538   }
    539 }
    540