MultiHeadAttention
public struct MultiHeadAttention<Element, Device> : LayerType, Codable where Element : RandomizableType, Device : DeviceType
Multi-Head Attention Layer following Attention Is All You Need.
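For reference, the underlying computation from the paper, where d_k corresponds to keyDim:

\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W^O, \qquad \mathrm{head}_i = \mathrm{Attention}(Q W_i^Q,\ K W_i^K,\ V W_i^V)

\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V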
-
Matrix multiplied with queries before dot product attention
Declaration
Swift
public var qDense: Tensor<Element, Device>
-
Matrix multiplied with keys before dot product attention
Declaration
Swift
public var kDense: Tensor<Element, Device>
-
Matrix multiplied with values before dot product attention
Declaration
Swift
public var vDense: Tensor<Element, Device>
-
Matrix multiplied with result from dot product attention layer
Declaration
Swift
public var fc: Tensor<Element, Device>
-
Scaled dot product attention layer applied to the projected queries, keys and values
Declaration
Swift
public var attn: ScaledDotProductAttention<Element, Device>
-
Layer normalization applied to the output of the residual connection
Declaration
Swift
public var norm: LayerNorm<Element, Device>
-
Dropout layer applied to the attended values
Declaration
Swift
public var dropout: Dropout<Element, Device>
-
Number of attention heads
Declaration
Swift
public let heads: Int
-
Dimensionality of query and key vectors
Declaration
Swift
public let keyDim: Int
-
Dimensionality of value vectors
Declaration
Swift
public let valueDim: Int
-
Last dimension of keys, queries and values before matrix multiplication
Declaration
Swift
public let hiddenDim: Int
-
Declaration
Swift
public var parameters: [Tensor<Element, Device>] { get }
-
Declaration
Swift
public var parameterPaths: [WritableKeyPath<Self, Tensor<Element, Device>>] { get }
-
Creates a Multi-Head Attention Layer following Attention Is All You Need.
Declaration
Swift
public init(heads: Int, hiddenDim: Int, keyDim: Int, valueDim: Int, dropout: Float = 0.1)
Parameters
heads
Number of attention heads
hiddenDim
Last dimension of keys, queries and values
keyDim
Last dimension of keys
valueDim
Intermediate last dimension of values
dropout
Dropout rate
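For illustration, a minimal construction sketch. The element and device types (Float, CPU) are placeholders; with 8 heads and a model dimension of 512, each head would typically use keyDim = valueDim = 512 / 8 = 64:
Swift
// Sketch: 8 heads over a 512-dimensional model; per-head key/query and
// value dimensions of 64 (512 / 8), with the default dropout rate.
let attention = MultiHeadAttention<Float, CPU>(
    heads: 8,
    hiddenDim: 512,
    keyDim: 64,
    valueDim: 64,
    dropout: 0.1
)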
-
Computes multi-head scaled dot product attention using the provided query, key and value vectors as well as the provided mask.
Additionally applies dropout, a residual connection and layer normalization.
Declaration
Swift
public func callAsFunction(_ inputs: (q: Tensor<Element, Device>, k: Tensor<Element, Device>, v: Tensor<Element, Device>, mask: Tensor<Element, Device>?)) -> Tensor<Element, Device>
Parameters
inputs
Tuple containing queries of shape [batchSize, queryCount, hiddenDim], keys of shape [batchSize, keyCount, hiddenDim] and values of shape [batchSize, valueCount, hiddenDim], as well as an optional mask that may be used to prevent attention to certain elements, e.g. padding positions or future timesteps. The mask must be broadcastable to shape [batchSize, heads, queryCount, keyCount] and contain a 1 at every position that should be blocked.
Return Value
Scaled dot product attended values after dropout, the residual connection and layer normalization have been applied
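A hedged usage sketch for self-attention, where queries, keys and values are the same tensor. The tuple labels (q, k, v, mask) and the repeating-value tensor initializer are assumptions; passing nil for the mask disables masking:
Swift
// Self-attention over a [batchSize, seqLen, hiddenDim] input.
// Tuple labels and the Tensor initializer are assumptions.
let batchSize = 16
let seqLen = 32
let x = Tensor<Float, CPU>(repeating: 0, shape: [batchSize, seqLen, 512])
let output = attention((q: x, k: x, v: x, mask: nil))
// output shape: [batchSize, seqLen, 512]
// For causal attention, pass a mask broadcastable to
// [batchSize, heads, queryCount, keyCount] with 1s at blocked positions
// (i.e. above the diagonal).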