This is an explanatory reference for PyTorch op registration with libtorch. It maps the schema notation you write in m.def("…") to the C++ types you implement, and shows minimal def/impl examples.
Tensor parameter notation
| Notation | Meaning | C++ Equivalent | Example Usage |
|---|---|---|---|
| Tensor | Input tensor (immutable) | const torch::Tensor& or torch::Tensor | Tensor input |
| Tensor! | Output/mutable tensor (in-place) | torch::Tensor& | Tensor! out |
| Tensor? | Optional tensor | const std::optional<torch::Tensor>& | Tensor? alibi_slopes |
| Tensor!? | Optional mutable tensor | std::optional<torch::Tensor>& | Tensor!? residual |
| Tensor[] | Array of tensors | std::vector<torch::Tensor> | Tensor[] handles |
| Tensor(a!)[] | Array of mutable tensors (alias 'a') | std::vector<torch::Tensor>& | Tensor(a!)[] key_caches |
| Tensor[](b!) | Mutable array of tensors (alias 'b') | std::vector<torch::Tensor>& | Tensor[](b!) value_caches |
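As a sketch of how these rows combine in one signature (the op name and arguments below are hypothetical, not from a real library):

```cpp
#include <torch/torch.h>
#include <optional>

// Hypothetical op combining rows of the table above.
//   Schema: "fused_add_bias(Tensor! out, Tensor input, Tensor? bias) -> ()"
void fused_add_bias(
    torch::Tensor& out,                         // Tensor!  -> mutable reference
    const torch::Tensor& input,                 // Tensor   -> const reference
    const std::optional<torch::Tensor>& bias);  // Tensor?  -> optional tensor
```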
Scalar parameter types
| Notation | C++ Type | Example Usage |
|---|---|---|
| int | int64_t | int num_kv_heads |
| float | double | float scale |
| bool | bool | bool is_neox |
| str | const std::string& | str kv_cache_dtype |
| SymInt | int64_t (symbolic integer) | SymInt size_m |
| ScalarType | at::ScalarType | ScalarType a_type |
| ScalarType? | std::optional<at::ScalarType> | ScalarType? out_type |
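A sketch of a declaration that uses several scalar rows at once (names are illustrative, not from a real library):

```cpp
#include <torch/torch.h>
#include <optional>
#include <string>

// Hypothetical declaration matching the schema:
//   "paged_attention(Tensor! out, Tensor query, int num_kv_heads, float scale,
//                    bool is_neox, str kv_cache_dtype, ScalarType? out_type) -> ()"
void paged_attention(torch::Tensor& out, const torch::Tensor& query,
                     int64_t num_kv_heads,                // schema int   -> int64_t
                     double scale,                        // schema float -> double
                     bool is_neox,                        // schema bool  -> bool
                     const std::string& kv_cache_dtype,   // schema str   -> std::string
                     std::optional<at::ScalarType> out_type);  // ScalarType?
```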
Array and collection types
| Notation | C++ Type | Example Usage |
|---|---|---|
| int[] | const std::vector<int64_t>& | int[] codebook_partition_sizes |
| str[] | std::vector<std::string> | str[] supported_schedules |
| Tensor[] | std::vector<torch::Tensor> | Tensor[] compressed_tensors |
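For example, a hypothetical declaration taking all three array kinds (op and parameter names are illustrative):

```cpp
#include <torch/torch.h>
#include <string>
#include <vector>

// Hypothetical declaration matching the schema:
//   "pack_tensors(Tensor[] compressed_tensors, int[] codebook_partition_sizes,
//                 str[] supported_schedules) -> Tensor"
torch::Tensor pack_tensors(const std::vector<torch::Tensor>& compressed_tensors,
                           const std::vector<int64_t>& codebook_partition_sizes,
                           const std::vector<std::string>& supported_schedules);
```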
Optional parameter notation
| Notation | Meaning | C++ Equivalent |
|---|---|---|
| ? suffix | Optional parameter | std::optional<T> or const std::optional<T>& |
| !? suffix | Optional mutable parameter | std::optional<T>& |
Examples:
- Tensor? bias → const std::optional<torch::Tensor>&
- int? group_size → std::optional<int64_t>
- Tensor!? output_lse → std::optional<torch::Tensor>&
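Inside the kernel, optional arguments are typically consumed with has_value()/value_or(). A minimal sketch, assuming a hypothetical scaled_mm op:

```cpp
#include <torch/torch.h>
#include <optional>

// Hypothetical op: "scaled_mm(Tensor a, Tensor b, Tensor? bias, int? group_size) -> Tensor"
torch::Tensor scaled_mm(const torch::Tensor& a, const torch::Tensor& b,
                        const std::optional<torch::Tensor>& bias,
                        std::optional<int64_t> group_size) {
  torch::Tensor out = torch::matmul(a, b);
  if (bias.has_value()) {
    out.add_(bias.value());  // only touch the tensor when the caller passed one
  }
  int64_t gs = group_size.value_or(-1);  // scalar optionals often fall back to a sentinel
  (void)gs;
  return out;
}
```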
Return types
| Notation | Meaning | C++ Return Type |
|---|---|---|
| -> () | Void return (in-place operation) | void |
| -> Tensor | Returns single tensor | torch::Tensor |
| -> Tensor[] | Returns array of tensors | std::vector<torch::Tensor> |
| -> int | Returns integer | int64_t |
| -> bool | Returns boolean | bool |
| -> str | Returns string | std::string |
| -> str[] | Returns array of strings | std::vector<std::string> |
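For the non-tensor return rows, declarations look like this (hypothetical op names, shown only to illustrate the mapping):

```cpp
#include <torch/torch.h>
#include <string>
#include <vector>

// "get_supported_schedules() -> str[]"
std::vector<std::string> get_supported_schedules();

// "is_layout_supported(Tensor weights, int group_size) -> bool"
bool is_layout_supported(const torch::Tensor& weights, int64_t group_size);
```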
Mutability, aliasing, and defaults
A ! annotation means the argument is mutated in place (use torch::Tensor& in C++); without it, the argument is read-only (prefer const torch::Tensor& when possible). Alias sets such as (a!) and (b!) exist only in the schema: they hint to alias analysis that arrays passed this way may share storage.
Defaults can be specified directly in the schema and mirrored in C++ if desired:
// Schema
"swigluoai_and_mul(Tensor! out, Tensor input, float alpha=1.702, float limit=7.0) -> ()"
// C++ signature
void swigluoai_and_mul(torch::Tensor& out, torch::Tensor& input, double alpha = 1.702, double limit = 7.0);
Minimal, concrete examples
In-place operation:
// Schema
"silu_and_mul(Tensor! result, Tensor input) -> ()"
// C++ Declaration (ops.h)
void silu_and_mul(torch::Tensor& out, torch::Tensor& input);
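Note that the parameter names in the C++ declaration (out) do not have to match the names in the schema (result); the dispatcher matches on argument types and order, not names.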
Returning a tensor:
// Schema
"aqlm_gemm(Tensor input, Tensor codes, Tensor codebooks, Tensor scales, int[] codebook_partition_sizes, Tensor? bias) -> Tensor"
// C++ Declaration (ops.h)
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
const torch::Tensor& codebooks, const torch::Tensor& scales,
const std::vector<int64_t>& codebook_partition_sizes,
const std::optional<torch::Tensor>& bias);
Returning multiple tensors:
// Schema
"top_k_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor? maybe_top_k_arr, int top_k_val, bool deterministic) -> Tensor[]"
// C++ Declaration (ops.h)
std::vector<torch::Tensor> top_k_sampling_from_probs(torch::Tensor probs, torch::Tensor uniform_samples,
std::optional<torch::Tensor> maybe_top_k_arr,
int64_t top_k_val, bool deterministic);
Aliased mutable arrays (alias sets a and b):
// Schema
"copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, Tensor block_mapping) -> ()"
// C++ Declaration (ops.h) - Arrays passed by reference for modification
void copy_blocks(std::vector<torch::Tensor>& key_caches, std::vector<torch::Tensor>& value_caches,
const torch::Tensor& block_mapping);
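To connect these examples to registration, here is a sketch of the corresponding def calls under a hypothetical example_ops namespace; backend bindings via impl follow the pattern in the next section:

```cpp
#include <torch/library.h>

// Sketch: registering the example schemas above (namespace is illustrative).
TORCH_LIBRARY(example_ops, m) {
  m.def("silu_and_mul(Tensor! result, Tensor input) -> ()");
  m.def(
      "aqlm_gemm(Tensor input, Tensor codes, Tensor codebooks, Tensor scales, "
      "int[] codebook_partition_sizes, Tensor? bias) -> Tensor");
  m.def(
      "copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, "
      "Tensor block_mapping) -> ()");
}
```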
def, impl, and meta kernels
Register a schema with def, then bind backend implementations with impl:
TORCH_LIBRARY(my_ops, m) {
m.def("weak_ref_tensor(Tensor input) -> Tensor");
m.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);
m.impl("weak_ref_tensor", torch::kCPU, &weak_ref_tensor_cpu);
}
Meta kernels mirror the schema and compute output shapes/dtypes without allocation. Register under the Meta key:
at::Tensor my_add_meta(const at::Tensor& a, const at::Tensor& b) {
TORCH_CHECK(a.sizes() == b.sizes(), "shape mismatch");
return at::empty_like(a, a.options());
}
TORCH_LIBRARY_IMPL(my_ops, Meta, m) {
m.impl("my_add", my_add_meta);
}
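For completeness, a sketch of the schema definition and a real-backend binding that the my_add meta kernel would pair with; my_add_cuda is a hypothetical kernel defined elsewhere:

```cpp
#include <ATen/ATen.h>
#include <torch/library.h>

at::Tensor my_add_cuda(const at::Tensor& a, const at::Tensor& b);  // hypothetical, defined elsewhere

// The def can live in the TORCH_LIBRARY(my_ops, m) block shown earlier,
// or in a fragment as sketched here.
TORCH_LIBRARY_FRAGMENT(my_ops, m) {
  m.def("my_add(Tensor a, Tensor b) -> Tensor");
}

TORCH_LIBRARY_IMPL(my_ops, CUDA, m) {
  m.impl("my_add", &my_add_cuda);  // real kernel that allocates and computes
}
```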
Optional attributes can be passed to def to override defaults (e.g., stride semantics). Prefer the defaults unless you have a specific, documented reason to change them.
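One such attribute is an operator tag passed as the second argument to def. A sketch, assuming the needs_fixed_stride_order tag, which asks torch.compile to pass inputs with the same stride order they would have in eager mode (op name and schema are illustrative):

```cpp
#include <torch/library.h>

// Sketch: attaching an operator tag when defining the schema.
TORCH_LIBRARY_FRAGMENT(my_ops, m) {
  m.def("fused_gemm(Tensor! out, Tensor a, Tensor b) -> ()",
        {at::Tag::needs_fixed_stride_order});
}
```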
Type mapping summary
| PyTorch Schema Type | C++ Header Type | Notes |
|---|---|---|
| Tensor | torch::Tensor or const torch::Tensor& | Input tensor |
| Tensor! | torch::Tensor& | Mutable tensor |
| Tensor? | std::optional<torch::Tensor> or const std::optional<torch::Tensor>& | Optional tensor |
| Tensor!? | std::optional<torch::Tensor>& | Optional mutable tensor |
| int | int64_t | Integer parameters |
| float | double | Floating point parameters |
| bool | bool | Boolean parameters |
| str | std::string or const std::string& | String parameters |
| SymInt | int64_t | Symbolic integers for dynamic shapes |
| ScalarType | at::ScalarType | PyTorch data types |
| int[] | std::vector<int64_t> or const std::vector<int64_t>& | Integer arrays |
| Tensor[] | std::vector<torch::Tensor> | Tensor arrays |
Use schemas as the contract. Keep C++ signatures consistent with the schema, mark all mutated tensors with !, prefer SymInt for sizes/strides, and make optional parameters truly optional in C++.