class SHAInet::EmbeddingLayer

Overview

A simple embedding lookup table that maps integer token IDs (Int32) to vectors of floats (Array(Float64)).

Defined in:

shainet/text/embedding_layer.cr

Instance methods inherited from class SHAInet::Layer

activation_function : Float32 | Float64 | Int32 | Int64 -> {Float64, Float64}
activations : Matrix(Float64)
biases : Matrix(Float64)
biases=(biases : Matrix(Float64))
clone
input_sums : Matrix(Float64)
input_sums=(input_sums : Matrix(Float64))
inspect
l_size : Int32
n_type : String
n_type=(n_type : String)
neurons : Array(SHAInet::Neuron)
neurons=(neurons : Array(SHAInet::Neuron))
propagate_forward_exp(prev_layer : Layer)
random_seed
sigma_primes : Matrix(Float64)
size : Int32
type_change(new_neuron_type : String)
weights : Matrix(Float64)
weights=(weights : Matrix(Float64))

Constructor methods inherited from class SHAInet::Layer

new(n_type : String, l_size : Int32, activation_function : ActivationFunction = SHAInet.sigmoid)

Constructor Detail

def self.new(l_size : Int32, activation_function : ActivationFunction = SHAInet.none)

Instance Method Detail

def accumulate_gradient

Accumulates gradients for the most recently embedded token IDs.
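
Gradient accumulation for an embedding table can be pictured as a sparse sum keyed by token ID. The Crystal sketch that follows is a minimal illustration, not the library's code: accumulate_gradient_sketch is a hypothetical helper, and passing the per-position gradients in explicitly is an assumption (the documented method takes no arguments and presumably reads them from the layer's own state).

```crystal
# Hypothetical helper (not part of SHAInet): adds the gradient for each
# position of the last embedded sequence into a sparse per-token table.
# ASSUMPTION: the real #accumulate_gradient takes no arguments and reads
# these values from layer state; here they are passed in explicitly.
def accumulate_gradient_sketch(gradients : Hash(Int32, Array(Float64)),
                               ids : Array(Int32),
                               upstream : Array(Array(Float64))) : Nil
  ids.each_with_index do |id, pos|
    grad = upstream[pos]
    # The first gradient seen for an id starts from a zero vector.
    gradients[id] ||= Array(Float64).new(grad.size, 0.0)
    acc = gradients[id]
    # Repeated ids sum their gradients into the same row.
    grad.each_with_index { |g, i| acc[i] += g }
  end
end

grads = Hash(Int32, Array(Float64)).new
accumulate_gradient_sketch(grads, [3, 7, 3], [[0.5, -1.0], [2.0, 0.0], [0.25, 1.0]])
p grads # => {3 => [0.75, 0.0], 7 => [2.0, 0.0]}
```

Summing rather than overwriting matters when a token appears more than once in the embedded sequence: every occurrence contributes to its single table row.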

def apply_gradients(lr : Float64)

Updates the embeddings using the stored gradients and the learning rate lr, then clears the stored gradients.
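
An update from stored gradients amounts to a sparse gradient-descent step. This Crystal sketch assumes plain SGD (each component moves by lr times its gradient, in the opposite direction); apply_gradients_sketch is a hypothetical name, and SHAInet's actual update rule may differ.

```crystal
# Hypothetical helper (not part of SHAInet): a sparse SGD step, e -= lr * g,
# touching only ids that have a stored gradient.
# ASSUMPTION: plain SGD; the library's exact update rule is not documented here.
def apply_gradients_sketch(embeddings : Hash(Int32, Array(Float64)),
                           gradients : Hash(Int32, Array(Float64)),
                           lr : Float64) : Nil
  gradients.each do |id, grad|
    # Gradients only exist for ids that were embedded, so a row should exist.
    if emb = embeddings[id]?
      grad.each_with_index { |g, i| emb[i] -= lr * g }
    end
  end
  # Clear the gradients so the next accumulation starts from zero.
  gradients.clear
end

embeddings = {1 => [1.0, 2.0], 2 => [0.0, 0.0]}
gradients = {1 => [0.5, -1.0]}
apply_gradients_sketch(embeddings, gradients, 0.5)
p embeddings # => {1 => [0.75, 2.5], 2 => [0.0, 0.0]}
```

Only rows with a stored gradient are touched, so the cost of an update scales with the number of distinct recently embedded tokens rather than with the vocabulary size.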

def current_ids : Array(Int32)

def embed(id : Int32) : Array(Float64)

Sets this layer's neuron activations to the embedding of the given token ID and returns that embedding vector.
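
In simplified form, embedding a token is a table read followed by a copy into the layer's activations. This Crystal sketch uses a hypothetical EmbedSketch class that keeps activations in a plain Array(Float64) rather than SHAInet's Matrix(Float64) and neurons, and it assumes the embedding length equals the layer size.

```crystal
# Hypothetical, simplified stand-in (not part of SHAInet).
# ASSUMPTIONS: activations are a plain Array(Float64) instead of the
# Matrix(Float64)/neuron storage the real layer uses, and an unknown id
# raises KeyError here instead of being created on demand.
class EmbedSketch
  getter activations : Array(Float64)

  def initialize(@table : Hash(Int32, Array(Float64)), dim : Int32)
    @activations = Array(Float64).new(dim, 0.0)
  end

  def embed(id : Int32) : Array(Float64)
    vec = @table[id]
    # Copy each component of the embedding into the matching activation slot.
    vec.each_with_index { |v, i| @activations[i] = v }
    vec
  end
end

layer = EmbedSketch.new({5 => [0.1, 0.2, 0.3]}, 3)
vec = layer.embed(5)
p layer.activations # => [0.1, 0.2, 0.3]
```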

def embeddings : Hash(Int32, Array(Float64))

def embeddings=(embeddings : Hash(Int32, Array(Float64)))

def gradients : Hash(Int32, Array(Float64))

def gradients=(gradients : Hash(Int32, Array(Float64)))

def lookup(id : Int32) : Array(Float64)

Returns the embedding vector for the given token ID. If the ID is not yet in the table, a new entry is created for it and initialized with random values.
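
Lazy initialization means the table only grows when an unseen ID is requested. This Crystal sketch uses a hypothetical LookupSketch class; the uniform range [-0.5, 0.5) and the seeded generator are assumptions, since the documentation does not specify the initialization distribution.

```crystal
# Hypothetical sketch (not part of SHAInet) of lookup with lazy random
# initialization.
# ASSUMPTIONS: values drawn uniformly from [-0.5, 0.5) with a seeded
# generator; the real initialization scheme is not documented here.
class LookupSketch
  getter embeddings = Hash(Int32, Array(Float64)).new
  @rng : Random

  def initialize(@dim : Int32, seed : Int32 = 42)
    @rng = Random.new(seed)
  end

  def lookup(id : Int32) : Array(Float64)
    # Reuse the stored row, or create and store a random one on first use.
    @embeddings[id] ||= Array(Float64).new(@dim) { @rng.rand - 0.5 }
  end
end

table = LookupSketch.new(4)
first = table.lookup(10)
again = table.lookup(10)
p first.same?(again) # => true
```

Because the new row is stored on first access, later lookups of the same ID return the same vector object rather than fresh random values.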
