# Broadcasting semantics

This document describes how the broadcasting semantics in XLA work.

## What is broadcasting?

Broadcasting is the process of making arrays with different shapes have
compatible shapes for arithmetic operations. The terminology is borrowed from
NumPy
[(broadcasting)](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).

Broadcasting may be required for operations between multi-dimensional arrays of
different ranks, or between multi-dimensional arrays with different but
compatible shapes. Consider the addition `X+v` where `X` is a matrix (an array
of rank 2) and `v` is a vector (an array of rank 1). To perform element-wise
addition, XLA needs to "broadcast" the vector `v` to the same rank as the
matrix `X`, by replicating `v` a certain number of times. The vector's length
has to match at least one of the dimensions of the matrix.

For example:

    |1 2 3| + |7 8 9|
    |4 5 6|

The matrix's dimensions are (2,3), the vector's are (3). The vector is broadcast
by replicating it over rows to get:

    |1 2 3| + |7 8 9| = |8  10 12|
    |4 5 6|   |7 8 9|   |11 13 15|

In NumPy, this is called
[broadcasting](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).

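The (2,3) plus (3) example above can be traced by hand in a few lines of plain
Python. This is only an illustrative sketch using nested lists, not XLA code:
the names `X`, `v`, `v_broadcast` and `result` are chosen here for the
illustration.

```python
# Pure-Python sketch of the (2,3) + (3) example: replicate v over the rows of X.
X = [[1, 2, 3],
     [4, 5, 6]]
v = [7, 8, 9]

# Step 1: broadcast v to the shape of X by repeating it once per row.
v_broadcast = [list(v) for _ in X]

# Step 2: ordinary element-wise addition of two same-shaped matrices.
result = [[x + b for x, b in zip(x_row, b_row)]
          for x_row, b_row in zip(X, v_broadcast)]

print(v_broadcast)  # [[7, 8, 9], [7, 8, 9]]
print(result)       # [[8, 10, 12], [11, 13, 15]]
```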
## Principles

The XLA language is as strict and explicit as possible, avoiding implicit and
"magical" features. Such features may make some computations slightly easier to
define, at the cost of more assumptions baked into user code that will be
difficult to change in the long term. If necessary, implicit and magical
features can be added in client-level wrappers.

With regard to broadcasting, explicit broadcasting specifications are required
on operations between arrays of different ranks. This is different from NumPy,
which infers the specification when possible.

## Broadcasting a lower-rank array onto a higher-rank array

*Scalars* can always be broadcast over arrays without an explicit specification
of broadcasting dimensions. An element-wise binary operation between a scalar
and an array means applying the operation with the scalar to each element in
the array. For example, adding a scalar to a matrix produces a matrix in which
each element is the sum of the scalar and the corresponding element of the
input matrix.

    |1 2 3| + 7 = |8  9  10|
    |4 5 6|       |11 12 13|

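As a rough illustration of the scalar case, the sketch below applies the
addition to every element of a matrix stored as a list of rows. The helper name
`add_scalar` is hypothetical and is not part of any XLA API.

```python
def add_scalar(matrix, scalar):
    """Add `scalar` to every element of a matrix given as a list of rows."""
    return [[element + scalar for element in row] for row in matrix]


print(add_scalar([[1, 2, 3], [4, 5, 6]], 7))  # [[8, 9, 10], [11, 12, 13]]
```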
Most broadcasting needs can be captured by using a tuple of dimensions on a
binary operation. When the inputs to the operation have different ranks, this
broadcasting tuple specifies which dimension(s) in the **higher-rank** array to
match with the **lower-rank** array.

Returning to the previous example, instead of adding a scalar to a (2,3) matrix,
add a vector of dimension (3) to a matrix of dimensions (2,3). *Without
specifying broadcasting, this operation is invalid.* To correctly request
matrix-vector addition, specify the broadcasting dimension to be (1), meaning
the vector's dimension is matched to dimension 1 of the matrix. In 2D, if
dimension 0 is considered as rows and dimension 1 as columns, this means that
each element of the vector becomes a column of a size matching the number of
rows in the matrix:

    |7 8 9| ==> |7 8 9|
                |7 8 9|

As a more complex example, consider adding a 3-element vector (dimension (3)) to
a 3x3 matrix (dimensions (3,3)). There are two ways broadcasting can happen for
this example:

(1) A broadcasting dimension of 1 can be used. Each vector element becomes a
column and the vector is duplicated for each row in the matrix.

    |7 8 9| ==> |7 8 9|
                |7 8 9|
                |7 8 9|

(2) A broadcasting dimension of 0 can be used. Each vector element becomes a row
and the vector is duplicated for each column in the matrix.

    |7| ==> |7 7 7|
    |8|     |8 8 8|
    |9|     |9 9 9|

> Note: when adding a 2x3 matrix to a 3-element vector, a broadcasting dimension
> of 0 is invalid, because the vector's length (3) does not match the size of
> dimension 0 of the matrix (2).

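The two ways of broadcasting a vector onto a matrix, and the invalid 2x3 case
from the note, can be sketched in plain Python. The function `broadcast_vector`
and its parameters are hypothetical names introduced for this illustration; the
sketch models the rule described above rather than XLA's implementation.

```python
def broadcast_vector(vector, rows, cols, dim):
    """Expand `vector` to a rows x cols matrix, matching it to dimension `dim`."""
    if dim not in (0, 1):
        raise ValueError(f"broadcast dimension must be 0 or 1, got {dim}")
    matched_size = rows if dim == 0 else cols
    if len(vector) != matched_size:
        raise ValueError(
            f"vector length {len(vector)} does not match "
            f"dimension {dim} of size {matched_size}")
    if dim == 1:
        # vector[j] fills column j; the vector is repeated for each row.
        return [list(vector) for _ in range(rows)]
    # dim == 0: vector[i] fills row i; the vector is repeated for each column.
    return [[vector[i]] * cols for i in range(rows)]


v = [7, 8, 9]
print(broadcast_vector(v, 3, 3, dim=1))  # [[7, 8, 9], [7, 8, 9], [7, 8, 9]]
print(broadcast_vector(v, 3, 3, dim=0))  # [[7, 7, 7], [8, 8, 8], [9, 9, 9]]

try:
    broadcast_vector(v, 2, 3, dim=0)  # 2x3 matrix, broadcast dimension 0
except ValueError as error:
    print("invalid:", error)
```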
The broadcasting dimensions can be a tuple that describes how a lower-rank
shape is broadcast into a higher-rank shape. For example, given a 2x3x4 cuboid
and a 3x4 matrix, a broadcasting tuple (1,2) means matching the matrix to
dimensions 1 and 2 of the cuboid.

This type of broadcast is used in the binary ops in `ComputationBuilder`, if the
`broadcast_dimensions` argument is given. For example, see
[ComputationBuilder::Add](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.cc).
In the XLA source code, this type of broadcasting is sometimes called "InDim"
broadcasting.

### Formal definition

The broadcasting attribute allows matching a lower-rank array to a higher-rank
array, by specifying which dimensions of the higher-rank array to match. For
example, for an array with dimensions MxNxPxQ, a vector with dimension T can be
matched as follows:

              MxNxPxQ

    dim 3:          T
    dim 2:        T
    dim 1:      T
    dim 0:    T

In each case, T has to be equal to the matching dimension of the higher-rank
array. The vector's values are then broadcast from the matched dimension to all
the other dimensions.

To match a TxV matrix onto the MxNxPxQ array, a pair of broadcasting dimensions
is used:

              MxNxPxQ
    dim 2,3:      T V
    dim 1,2:    T V
    dim 0,3:  T     V
    etc...

The order of dimensions in the broadcasting tuple has to be the order in which
the lower-rank array's dimensions are expected to match the higher-rank array's
dimensions. The first element in the tuple says which dimension in the
higher-rank array has to match dimension 0 in the lower-rank array. The second
element does the same for dimension 1, and so on. The broadcast dimensions have
to be listed in strictly increasing order. For example, in the previous example
it is illegal to match V to N and T to P; it is also illegal to match V to both
P and N.

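The rules in this section can be written down as a small checker. The sketch
below is a plain-Python model of the rules as stated here (one entry per
lower-rank dimension, strictly increasing, equal sizes); the function names
`check_broadcast_dimensions` and `source_index`, and the concrete sizes chosen
for M, N, P and Q, are assumptions made for the illustration and do not come
from the XLA source.

```python
def check_broadcast_dimensions(low_shape, high_shape, broadcast_dimensions):
    """Raise ValueError unless the tuple validly maps low_shape onto high_shape."""
    if len(broadcast_dimensions) != len(low_shape):
        raise ValueError("need one broadcast dimension per lower-rank dimension")
    previous = -1
    for low_dim, high_dim in enumerate(broadcast_dimensions):
        if not 0 <= high_dim < len(high_shape):
            raise ValueError(f"dimension {high_dim} is out of range")
        if high_dim <= previous:
            raise ValueError("broadcast dimensions must be strictly increasing")
        if low_shape[low_dim] != high_shape[high_dim]:
            raise ValueError(
                f"size {low_shape[low_dim]} does not match dimension {high_dim}")
        previous = high_dim


def source_index(high_index, broadcast_dimensions):
    """Index into the lower-rank array that supplies the value at high_index."""
    return tuple(high_index[d] for d in broadcast_dimensions)


high = (2, 3, 4, 5)  # hypothetical sizes for M, N, P, Q
check_broadcast_dimensions((4,), high, (2,))      # vector T=4 matched to P
check_broadcast_dimensions((3, 4), high, (1, 2))  # T=3 to N, V=4 to P
check_broadcast_dimensions((2, 5), high, (0, 3))  # T=2 to M, V=5 to Q
print(source_index((1, 2, 3, 4), (1, 2)))         # (2, 3)

try:
    check_broadcast_dimensions((4, 3), high, (2, 1))  # T to P and V to N
except ValueError as error:
    print("rejected:", error)
```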
## Broadcasting similar-rank arrays with degenerate dimensions

A related broadcasting problem is broadcasting two arrays that have the same
rank but different dimension sizes. As with NumPy's rules, this is only
possible when the arrays are *compatible*. Two arrays are compatible when all
their dimensions are compatible. Two dimensions are compatible if:

*   They are equal, or
*   One of them is 1 (a "degenerate" dimension)

When two compatible arrays are encountered, the result shape takes, at every
dimension index, the larger of the two input sizes.

Examples:

1.  (2,1) and (2,3) broadcast to (2,3).
2.  (1,2,5) and (7,2,5) broadcast to (7,2,5).
3.  (7,2,5) and (7,1,5) broadcast to (7,2,5).
4.  (7,2,5) and (7,2,6) are incompatible and cannot be broadcast.

A special case arises, and is also supported, where each of the input arrays has
a degenerate dimension at a different index. In this case, the result is an
"outer operation": (2,1) and (1,3) broadcast to (2,3). For more examples,
consult the [NumPy documentation on
broadcasting](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).

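The compatibility rule and the resulting shape can be sketched as a shape
function in plain Python. The name `degenerate_broadcast_shape` is hypothetical;
the code models the rule described in this section and is not taken from XLA.

```python
def degenerate_broadcast_shape(shape_a, shape_b):
    """Result shape for two same-rank shapes, or ValueError if incompatible."""
    if len(shape_a) != len(shape_b):
        raise ValueError("shapes must have the same rank")
    result = []
    for size_a, size_b in zip(shape_a, shape_b):
        if size_a != size_b and 1 not in (size_a, size_b):
            raise ValueError(f"sizes {size_a} and {size_b} are incompatible")
        # Keep the non-degenerate size; for sizes >= 1 this is the larger one.
        result.append(size_b if size_a == 1 else size_a)
    return tuple(result)


print(degenerate_broadcast_shape((2, 1), (2, 3)))        # (2, 3)
print(degenerate_broadcast_shape((1, 2, 5), (7, 2, 5)))  # (7, 2, 5)
print(degenerate_broadcast_shape((7, 2, 5), (7, 1, 5)))  # (7, 2, 5)
print(degenerate_broadcast_shape((2, 1), (1, 3)))        # (2, 3), the "outer" case

try:
    degenerate_broadcast_shape((7, 2, 5), (7, 2, 6))
except ValueError as error:
    print("incompatible:", error)
```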
## Broadcast composition

Broadcasting of a lower-rank array to a higher-rank array **and** broadcasting
using degenerate dimensions can both be performed in the same binary operation.
For example, a vector of size 4 and a matrix of size 1x2 can be added together
using a broadcast dimensions value of (0):

    |1 2 3 4| + [5 6]    // [5 6] is a 1x2 matrix, not a vector.

First the vector is broadcast up to rank 2 (matrix) using the broadcast
dimensions. The single value (0) in the broadcast dimensions indicates that
dimension zero of the vector matches dimension zero of the matrix. This
produces a matrix of size 4xM where the value M is chosen to match the
corresponding dimension size in the 1x2 array. Therefore, a 4x2 matrix is
produced:

    |1 1| + [5 6]
    |2 2|
    |3 3|
    |4 4|

Then "degenerate dimension broadcasting" broadcasts dimension zero of the 1x2
matrix to match the corresponding dimension size of the 4x2 matrix:

    |1 1| + |5 6|     |6  7|
    |2 2| + |5 6|  =  |7  8|
    |3 3| + |5 6|     |8  9|
    |4 4| + |5 6|     |9 10|

A more complicated example is a matrix of size 1x2 added to an array of size
4x3x1 using broadcast dimensions of (1, 2). First the 1x2 matrix is broadcast up
to rank 3 using the broadcast dimensions to produce an intermediate Mx1x2 array
where the dimension size M is determined by the size of the larger operand (the
4x3x1 array), producing a 4x1x2 intermediate array. The M is at dimension 0
(left-most dimension) because dimensions 1 and 2 are mapped to the dimensions
of the original 1x2 matrix, as the broadcast dimensions are (1, 2). This
intermediate array can be added to the 4x3x1 array using broadcasting of
degenerate dimensions to produce a 4x3x2 array result.

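The two-step composition described in this section can be followed at the level
of shapes with a short plain-Python sketch. The function
`composed_result_shape` is a hypothetical helper written for this illustration:
it first builds the intermediate shape from the broadcast dimensions, then
applies degenerate-dimension broadcasting against the higher-rank operand. It
models the description above and is not XLA's shape inference code.

```python
def composed_result_shape(low_shape, high_shape, broadcast_dimensions):
    """Return (intermediate_shape, result_shape) for a composed broadcast."""
    # Step 1: raise the lower-rank operand to the rank of the higher-rank one.
    # Mapped dimensions keep the lower-rank sizes; the remaining dimensions
    # take their size from the higher-rank operand.
    intermediate = list(high_shape)
    for low_dim, high_dim in enumerate(broadcast_dimensions):
        intermediate[high_dim] = low_shape[low_dim]

    # Step 2: degenerate-dimension broadcasting between two same-rank shapes.
    result = []
    for size_i, size_h in zip(intermediate, high_shape):
        if size_i != size_h and 1 not in (size_i, size_h):
            raise ValueError(f"sizes {size_i} and {size_h} are incompatible")
        result.append(size_h if size_i == 1 else size_i)
    return tuple(intermediate), tuple(result)


# Vector of size 4 plus a 1x2 matrix, broadcast dimensions (0).
print(composed_result_shape((4,), (1, 2), (0,)))         # ((4, 2), (4, 2))
# 1x2 matrix plus a 4x3x1 array, broadcast dimensions (1, 2).
print(composed_result_shape((1, 2), (4, 3, 1), (1, 2)))  # ((4, 1, 2), (4, 3, 2))

# Element values for the first example: result[i][j] = vector[i] + matrix[0][j].
vector = [1, 2, 3, 4]
matrix = [[5, 6]]
total = [[vector[i] + matrix[0][j] for j in range(2)] for i in range(4)]
print(total)  # [[6, 7], [7, 8], [8, 9], [9, 10]]
```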