1use base64::engine::{Engine, general_purpose as b64_general_purpose};
8use hmac::Mac;
9use reqwest::header::{HeaderMap, HeaderValue};
10use serde::{Deserialize, Serialize};
11use sha2::Sha256;
12use std::time::{Duration, SystemTime, UNIX_EPOCH};
13
/// Length, in bytes, of an [`HmacKey`].
pub const HMAC_KEY_LEN: usize = 32;

/// Header that carries the (unpadded, standard base64) request signature.
const RENEGADE_AUTH_HEADER_NAME: &str = "x-renegade-auth";

/// Header that carries the signature expiration, in milliseconds since the Unix epoch.
const RENEGADE_SIG_EXPIRATION_HEADER_NAME: &str = "x-renegade-auth-expiration";

/// Name prefix of headers that are covered by the request signature.
/// The signature header itself (`RENEGADE_AUTH_HEADER_NAME`) is excluded from signing.
const RENEGADE_HEADER_NAMESPACE: &str = "x-renegade";

/// HMAC instantiated over SHA-256.
type HmacSha256 = hmac::Hmac<Sha256>;
36
/// A fixed-length symmetric key used to compute and verify HMAC-SHA256 request signatures.
///
/// NOTE(review): the derived `Debug` prints the raw key bytes, and the derived `PartialEq`
/// is not guaranteed to be constant-time. Avoid logging this value or comparing keys where
/// timing is observable by an attacker.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct HmacKey(pub [u8; HMAC_KEY_LEN]);
40
#[cfg(feature = "darkpool-client")]
impl From<HmacKey> for renegade_types_core::HmacKey {
    /// Copy the raw key bytes into the core crate's `HmacKey` representation.
    fn from(HmacKey(bytes): HmacKey) -> Self {
        Self(bytes)
    }
}
47
#[cfg(feature = "darkpool-client")]
impl From<renegade_types_core::HmacKey> for HmacKey {
    /// Build a key from the core crate's `HmacKey` by copying its raw bytes.
    fn from(core_key: renegade_types_core::HmacKey) -> Self {
        let bytes = core_key.0;
        HmacKey(bytes)
    }
}
54
55impl HmacKey {
56 pub fn new(hex: &str) -> Result<Self, String> {
58 Self::from_hex_string(hex)
59 }
60
61 pub fn inner(&self) -> &[u8; HMAC_KEY_LEN] {
63 &self.0
64 }
65
66 #[cfg(feature = "darkpool-client")]
68 pub fn random() -> Self {
69 use rand::RngCore;
70 let mut rng = rand::thread_rng();
71 let mut bytes = [0; HMAC_KEY_LEN];
72 rng.fill_bytes(&mut bytes);
73 Self(bytes)
74 }
75
76 pub fn to_hex_string(&self) -> String {
78 format!("0x{}", hex::encode(self.0))
79 }
80
81 pub fn from_hex_string(hex_str: &str) -> Result<Self, String> {
83 let hex_str = hex_str.strip_prefix("0x").unwrap_or(hex_str);
84 let bytes = hex::decode(hex_str)
85 .map_err(|e| format!("error deserializing bytes from hex string: {e}"))?;
86
87 if bytes.len() != HMAC_KEY_LEN {
88 return Err(format!("expected {HMAC_KEY_LEN} byte HMAC key, got {}", bytes.len()));
89 }
90
91 Ok(Self(bytes.try_into().unwrap()))
92 }
93
94 pub fn to_base64_string(&self) -> String {
96 b64_general_purpose::STANDARD.encode(self.0)
97 }
98
99 pub fn from_base64_string(base64: &str) -> Result<Self, String> {
101 let bytes = b64_general_purpose::STANDARD.decode(base64).map_err(|e| e.to_string())?;
102 if bytes.len() != HMAC_KEY_LEN {
103 return Err(format!("expected {HMAC_KEY_LEN} byte HMAC key, got {}", bytes.len()));
104 }
105
106 Ok(Self(bytes.try_into().unwrap()))
107 }
108
109 pub fn from_bytes(bytes: &[u8]) -> Result<Self, String> {
111 if bytes.len() != HMAC_KEY_LEN {
112 return Err(format!("expected {HMAC_KEY_LEN} byte HMAC key, got {}", bytes.len()));
113 }
114
115 Ok(Self(bytes.try_into().unwrap()))
116 }
117
118 pub fn compute_mac(&self, msg: &[u8]) -> Vec<u8> {
120 let mut hmac =
121 HmacSha256::new_from_slice(self.inner()).expect("hmac can handle all slice lengths");
122 hmac.update(msg);
123 hmac.finalize().into_bytes().to_vec()
124 }
125
126 pub fn verify_mac(&self, msg: &[u8], mac: &[u8]) -> bool {
128 self.compute_mac(msg) == mac
129 }
130}
131
132pub fn add_expiring_auth_to_headers(
138 path: &str,
139 headers: &mut HeaderMap,
140 body: &[u8],
141 key: &HmacKey,
142 expiration: Duration,
143) {
144 let now_millis =
146 SystemTime::now().duration_since(UNIX_EPOCH).expect("negative timestamp").as_millis()
147 as u64;
148 let expiration_ts = now_millis + expiration.as_millis() as u64;
149 headers.insert(RENEGADE_SIG_EXPIRATION_HEADER_NAME, expiration_ts.into());
150
151 let sig = create_request_signature(path, headers, body, key);
153 let b64_sig = b64_general_purpose::STANDARD_NO_PAD.encode(sig);
154 let sig_header = HeaderValue::from_str(&b64_sig).expect("b64 encoding should not fail");
155 headers.insert(RENEGADE_AUTH_HEADER_NAME, sig_header);
156}
157
158fn create_request_signature(
164 path: &str,
165 headers: &HeaderMap,
166 body: &[u8],
167 key: &HmacKey,
168) -> Vec<u8> {
169 let path_bytes = path.as_bytes();
170 let header_bytes = get_header_bytes(headers);
171 let payload = [path_bytes, &header_bytes, body].concat();
172
173 key.compute_mac(&payload)
174}
175
176fn get_header_bytes(headers: &HeaderMap) -> Vec<u8> {
178 let mut headers_buf = Vec::new();
179
180 let mut renegade_headers = headers
182 .iter()
183 .filter_map(|(k, v)| {
184 let key = k.to_string().to_lowercase();
185 if key.starts_with(RENEGADE_HEADER_NAMESPACE) && key != RENEGADE_AUTH_HEADER_NAME {
186 Some((key, v))
187 } else {
188 None
189 }
190 })
191 .collect::<Vec<_>>();
192
193 renegade_headers.sort_by(|a, b| a.0.cmp(&b.0));
195 for (key, value) in renegade_headers {
196 headers_buf.extend_from_slice(key.as_bytes());
197 headers_buf.extend_from_slice(value.as_bytes());
198 }
199
200 headers_buf
201}