renegade_sdk/
auth.rs

1//! Types and utilities for HMAC-based authentication
2//!
3//! Inlines `HmacKey` (from `types-core`) and `add_expiring_auth_to_headers`
4//! (from `external-api`) so the `external-match-client` feature path does not
5//! pull in the full renegade dependency tree.
6
7use base64::engine::{Engine, general_purpose as b64_general_purpose};
8use hmac::Mac;
9use reqwest::header::{HeaderMap, HeaderValue};
10use serde::{Deserialize, Serialize};
11use sha2::Sha256;
12use std::time::{Duration, SystemTime, UNIX_EPOCH};
13
14// -------------
15// | Constants |
16// -------------
17
/// Number of bytes in an HMAC key
pub const HMAC_KEY_LEN: usize = 32;

/// Namespace prefix: every header whose name begins with this string is
/// covered by the request signature (except the signature header itself)
const RENEGADE_HEADER_NAMESPACE: &str = "x-renegade";

/// Header carrying the base64-encoded request signature
const RENEGADE_AUTH_HEADER_NAME: &str = "x-renegade-auth";

/// Header carrying the signature's expiration timestamp (milliseconds since
/// the Unix epoch)
const RENEGADE_SIG_EXPIRATION_HEADER_NAME: &str = "x-renegade-auth-expiration";
29
30// ---------
31// | Types |
32// ---------
33
34/// Type alias for the hmac core implementation
35type HmacSha256 = hmac::Hmac<Sha256>;
36
37/// A type representing a symmetric HMAC key
38#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
39pub struct HmacKey(pub [u8; HMAC_KEY_LEN]);
40
/// Convert the SDK's inlined key into the `renegade_types_core` key type
#[cfg(feature = "darkpool-client")]
impl From<HmacKey> for renegade_types_core::HmacKey {
    fn from(key: HmacKey) -> Self {
        let HmacKey(bytes) = key;
        renegade_types_core::HmacKey(bytes)
    }
}

/// Convert a `renegade_types_core` key into the SDK's inlined key type
#[cfg(feature = "darkpool-client")]
impl From<renegade_types_core::HmacKey> for HmacKey {
    fn from(key: renegade_types_core::HmacKey) -> Self {
        let renegade_types_core::HmacKey(bytes) = key;
        HmacKey(bytes)
    }
}
54
55impl HmacKey {
56    /// Create a new HMAC key from a hex string
57    pub fn new(hex: &str) -> Result<Self, String> {
58        Self::from_hex_string(hex)
59    }
60
61    /// Get the inner bytes
62    pub fn inner(&self) -> &[u8; HMAC_KEY_LEN] {
63        &self.0
64    }
65
66    /// Create a new random HMAC key
67    #[cfg(feature = "darkpool-client")]
68    pub fn random() -> Self {
69        use rand::RngCore;
70        let mut rng = rand::thread_rng();
71        let mut bytes = [0; HMAC_KEY_LEN];
72        rng.fill_bytes(&mut bytes);
73        Self(bytes)
74    }
75
76    /// Convert the HMAC key to a hex string
77    pub fn to_hex_string(&self) -> String {
78        format!("0x{}", hex::encode(self.0))
79    }
80
81    /// Try to convert a hex string to an HMAC key
82    pub fn from_hex_string(hex_str: &str) -> Result<Self, String> {
83        let hex_str = hex_str.strip_prefix("0x").unwrap_or(hex_str);
84        let bytes = hex::decode(hex_str)
85            .map_err(|e| format!("error deserializing bytes from hex string: {e}"))?;
86
87        if bytes.len() != HMAC_KEY_LEN {
88            return Err(format!("expected {HMAC_KEY_LEN} byte HMAC key, got {}", bytes.len()));
89        }
90
91        Ok(Self(bytes.try_into().unwrap()))
92    }
93
94    /// Convert the HMAC key to a base64 string
95    pub fn to_base64_string(&self) -> String {
96        b64_general_purpose::STANDARD.encode(self.0)
97    }
98
99    /// Try to convert a base64 string to an HMAC key
100    pub fn from_base64_string(base64: &str) -> Result<Self, String> {
101        let bytes = b64_general_purpose::STANDARD.decode(base64).map_err(|e| e.to_string())?;
102        if bytes.len() != HMAC_KEY_LEN {
103            return Err(format!("expected {HMAC_KEY_LEN} byte HMAC key, got {}", bytes.len()));
104        }
105
106        Ok(Self(bytes.try_into().unwrap()))
107    }
108
109    /// Try to create an HMAC key from a byte slice
110    pub fn from_bytes(bytes: &[u8]) -> Result<Self, String> {
111        if bytes.len() != HMAC_KEY_LEN {
112            return Err(format!("expected {HMAC_KEY_LEN} byte HMAC key, got {}", bytes.len()));
113        }
114
115        Ok(Self(bytes.try_into().unwrap()))
116    }
117
118    /// Compute the HMAC of a message
119    pub fn compute_mac(&self, msg: &[u8]) -> Vec<u8> {
120        let mut hmac =
121            HmacSha256::new_from_slice(self.inner()).expect("hmac can handle all slice lengths");
122        hmac.update(msg);
123        hmac.finalize().into_bytes().to_vec()
124    }
125
126    /// Verify the HMAC of a message
127    pub fn verify_mac(&self, msg: &[u8], mac: &[u8]) -> bool {
128        self.compute_mac(msg) == mac
129    }
130}
131
132// --------------------
133// | Public Interface |
134// --------------------
135
136/// Add an auth expiration and signature to a set of headers
137pub fn add_expiring_auth_to_headers(
138    path: &str,
139    headers: &mut HeaderMap,
140    body: &[u8],
141    key: &HmacKey,
142    expiration: Duration,
143) {
144    // Add a timestamp
145    let now_millis =
146        SystemTime::now().duration_since(UNIX_EPOCH).expect("negative timestamp").as_millis()
147            as u64;
148    let expiration_ts = now_millis + expiration.as_millis() as u64;
149    headers.insert(RENEGADE_SIG_EXPIRATION_HEADER_NAME, expiration_ts.into());
150
151    // Add the signature
152    let sig = create_request_signature(path, headers, body, key);
153    let b64_sig = b64_general_purpose::STANDARD_NO_PAD.encode(sig);
154    let sig_header = HeaderValue::from_str(&b64_sig).expect("b64 encoding should not fail");
155    headers.insert(RENEGADE_AUTH_HEADER_NAME, sig_header);
156}
157
158// -----------
159// | Helpers |
160// -----------
161
162/// Create a request signature
163fn create_request_signature(
164    path: &str,
165    headers: &HeaderMap,
166    body: &[u8],
167    key: &HmacKey,
168) -> Vec<u8> {
169    let path_bytes = path.as_bytes();
170    let header_bytes = get_header_bytes(headers);
171    let payload = [path_bytes, &header_bytes, body].concat();
172
173    key.compute_mac(&payload)
174}
175
176/// Get the header bytes to include in an HMAC
177fn get_header_bytes(headers: &HeaderMap) -> Vec<u8> {
178    let mut headers_buf = Vec::new();
179
180    // Filter out non-Renegade headers and the auth header
181    let mut renegade_headers = headers
182        .iter()
183        .filter_map(|(k, v)| {
184            let key = k.to_string().to_lowercase();
185            if key.starts_with(RENEGADE_HEADER_NAMESPACE) && key != RENEGADE_AUTH_HEADER_NAME {
186                Some((key, v))
187            } else {
188                None
189            }
190        })
191        .collect::<Vec<_>>();
192
193    // Sort alphabetically, then add to the buffer
194    renegade_headers.sort_by(|a, b| a.0.cmp(&b.0));
195    for (key, value) in renegade_headers {
196        headers_buf.extend_from_slice(key.as_bytes());
197        headers_buf.extend_from_slice(value.as_bytes());
198    }
199
200    headers_buf
201}