1use base64::engine::general_purpose::URL_SAFE_NO_PAD;
2use base64::Engine;
3use serde::{Deserialize, Serialize};
4
5use crate::storage::cloud::CloudHomeJoinInfo;
6
/// Payload carried inside an invite code.
///
/// [`encode`] serializes this struct to JSON and wraps it in unpadded,
/// URL-safe base64; [`decode`] reverses that process.
#[derive(Serialize, Deserialize)]
pub struct InviteCode {
    /// Identifier of the library the invite refers to.
    pub library_id: String,
    /// Name of the library the invite refers to.
    pub library_name: String,
    /// Provider-specific details needed to join the library's cloud home.
    ///
    /// Depending on the variant this can embed credentials (the S3 variant
    /// used in the tests below carries an access key and a secret key), and
    /// base64 is an encoding, not encryption.
    pub join_info: CloudHomeJoinInfo,
    /// Public key of the library owner.
    ///
    /// NOTE(review): presumably hex-encoded like
    /// [`JoinRequestCode::public_key`] — confirm against the code that builds
    /// invites.
    pub owner_pubkey: String,
}
14
15pub fn encode(code: &InviteCode) -> String {
16 let json = serde_json::to_vec(code).expect("InviteCode is always serializable");
17 URL_SAFE_NO_PAD.encode(&json)
18}
19
20pub fn decode(s: &str) -> Result<InviteCode, JoinCodeError> {
21 let bytes = URL_SAFE_NO_PAD
22 .decode(s.trim())
23 .map_err(|_| JoinCodeError::InvalidBase64)?;
24 serde_json::from_slice(&bytes).map_err(|e| JoinCodeError::InvalidJson(e.to_string()))
25}
26
27#[derive(Serialize, Deserialize)]
28pub struct JoinRequestCode {
29 pub public_key: String,
30 #[serde(skip_serializing_if = "Option::is_none")]
31 pub email: Option<String>,
32}
33
34pub fn generate_join_request(
37 needs_email: bool,
38 email: String,
39) -> Result<String, crate::keys::KeyError> {
40 let dev_mode = crate::config::Config::is_dev_mode();
41 let global_ks = crate::keys::KeyService::new(dev_mode, "global".to_string());
42 let keypair = global_ks.get_or_create_user_keypair()?;
43
44 let code = JoinRequestCode {
45 public_key: hex::encode(keypair.public_key),
46 email: if needs_email { Some(email) } else { None },
47 };
48
49 Ok(encode_join_request(&code))
50}
51
52pub fn encode_join_request(code: &JoinRequestCode) -> String {
53 let json = serde_json::to_vec(code).expect("JoinRequestCode is always serializable");
54 URL_SAFE_NO_PAD.encode(&json)
55}
56
57pub fn decode_join_request(s: &str) -> Result<JoinRequestCode, JoinCodeError> {
58 let bytes = URL_SAFE_NO_PAD
59 .decode(s.trim())
60 .map_err(|_| JoinCodeError::InvalidBase64)?;
61 serde_json::from_slice(&bytes).map_err(|e| JoinCodeError::InvalidJson(e.to_string()))
62}
63
/// Summary of an invite code that omits the provider-specific join details.
///
/// Produced by [`decode_invite_code_info`], which copies the identifying
/// fields of an [`InviteCode`] and keeps only the cloud provider derived from
/// its `join_info`.
pub struct InviteCodeInfo {
    /// Library identifier, copied from the invite.
    pub library_id: String,
    /// Library name, copied from the invite.
    pub library_name: String,
    /// Owner public key, copied from the invite.
    pub owner_pubkey: String,
    /// Cloud provider reported by the invite's `join_info`.
    pub cloud_provider: crate::config::CloudProvider,
}
71
72pub fn decode_invite_code_info(code: &str) -> Result<InviteCodeInfo, JoinCodeError> {
74 let invite = decode(code)?;
75 let cloud_provider = invite.join_info.cloud_provider();
76 Ok(InviteCodeInfo {
77 library_id: invite.library_id,
78 library_name: invite.library_name,
79 owner_pubkey: invite.owner_pubkey,
80 cloud_provider,
81 })
82}
83
/// Errors returned when decoding an invite code or a join request code.
#[derive(Debug, thiserror::Error)]
pub enum JoinCodeError {
    /// The input (after trimming surrounding whitespace) is not valid
    /// unpadded URL-safe base64.
    #[error("invalid base64url encoding")]
    InvalidBase64,
    /// The decoded bytes could not be deserialized from JSON into the expected
    /// payload; the string is the `serde_json` error message.
    ///
    /// This variant is also returned by [`decode_join_request`], even though
    /// its display message mentions an invite code.
    #[error("invalid invite code payload: {0}")]
    InvalidJson(String),
}
91
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps `join_info` in an [`InviteCode`] with the given identity fields.
    fn invite(
        library_id: &str,
        library_name: &str,
        owner_pubkey: &str,
        join_info: CloudHomeJoinInfo,
    ) -> InviteCode {
        InviteCode {
            library_id: library_id.to_owned(),
            library_name: library_name.to_owned(),
            join_info,
            owner_pubkey: owner_pubkey.to_owned(),
        }
    }

    /// Builds a [`JoinRequestCode`] from borrowed values.
    fn join_request(public_key: &str, email: Option<&str>) -> JoinRequestCode {
        JoinRequestCode {
            public_key: public_key.to_owned(),
            email: email.map(str::to_owned),
        }
    }

    // An S3 invite without the optional fields survives encode + decode.
    #[test]
    fn round_trip_s3() {
        let original = invite(
            "lib-123",
            "My Library",
            "deadbeef",
            CloudHomeJoinInfo::S3 {
                bucket: "my-bucket".into(),
                region: "us-east-1".into(),
                endpoint: None,
                access_key: "AKIAEXAMPLE".into(),
                secret_key: "secret123".into(),
                key_prefix: None,
            },
        );

        let restored = decode(&encode(&original)).unwrap();

        assert_eq!(restored.library_id, "lib-123");
        assert_eq!(restored.library_name, "My Library");
        assert_eq!(restored.owner_pubkey, "deadbeef");
        if let CloudHomeJoinInfo::S3 {
            bucket,
            region,
            endpoint,
            access_key,
            secret_key,
            key_prefix,
        } = restored.join_info
        {
            assert_eq!(bucket, "my-bucket");
            assert_eq!(region, "us-east-1");
            assert_eq!(endpoint, None);
            assert_eq!(access_key, "AKIAEXAMPLE");
            assert_eq!(secret_key, "secret123");
            assert_eq!(key_prefix, None);
        } else {
            panic!("expected S3 variant");
        }
    }

    // A custom S3 endpoint is preserved through encode + decode.
    #[test]
    fn round_trip_s3_with_endpoint() {
        let original = invite(
            "lib-456",
            "Shared",
            "cafebabe",
            CloudHomeJoinInfo::S3 {
                bucket: "bucket".into(),
                region: "eu-west-1".into(),
                endpoint: Some("https://s3.example.com".into()),
                access_key: "ak".into(),
                secret_key: "sk".into(),
                key_prefix: None,
            },
        );

        let restored = decode(&encode(&original)).unwrap();

        assert_eq!(restored.library_id, "lib-456");
        if let CloudHomeJoinInfo::S3 { endpoint, .. } = restored.join_info {
            assert_eq!(endpoint, Some("https://s3.example.com".to_string()));
        } else {
            panic!("expected S3 variant");
        }
    }

    // A Google Drive invite survives encode + decode.
    #[test]
    fn round_trip_google_drive() {
        let original = invite(
            "lib-789",
            "Cloud Shared",
            "cafebabe",
            CloudHomeJoinInfo::GoogleDrive {
                folder_id: "abc123".into(),
            },
        );

        let restored = decode(&encode(&original)).unwrap();

        assert_eq!(restored.library_id, "lib-789");
        if let CloudHomeJoinInfo::GoogleDrive { folder_id } = restored.join_info {
            assert_eq!(folder_id, "abc123");
        } else {
            panic!("expected GoogleDrive variant");
        }
    }

    // Characters outside the base64url alphabet are rejected up front.
    #[test]
    fn decode_invalid_base64() {
        let outcome = decode("not-valid!!!");
        assert!(matches!(outcome, Err(JoinCodeError::InvalidBase64)));
    }

    // Valid base64 that does not contain JSON is reported as a payload error.
    #[test]
    fn decode_invalid_json() {
        let not_json = URL_SAFE_NO_PAD.encode(b"not json");
        let outcome = decode(&not_json);
        assert!(matches!(outcome, Err(JoinCodeError::InvalidJson(_))));
    }

    // A CloudKit invite survives encode + decode.
    #[test]
    fn round_trip_cloudkit() {
        let original = invite(
            "lib-ck",
            "CloudKit Library",
            "aabbccdd",
            CloudHomeJoinInfo::CloudKit {
                share_url: "https://www.icloud.com/share/abc123".into(),
            },
        );

        let restored = decode(&encode(&original)).unwrap();

        assert_eq!(restored.library_id, "lib-ck");
        if let CloudHomeJoinInfo::CloudKit { share_url } = restored.join_info {
            assert_eq!(share_url, "https://www.icloud.com/share/abc123");
        } else {
            panic!("expected CloudKit variant");
        }
    }

    // Whitespace surrounding a pasted invite code is ignored.
    #[test]
    fn decode_trims_whitespace() {
        let original = invite(
            "lib-ws",
            "Trimmed",
            "aabb",
            CloudHomeJoinInfo::Dropbox {
                shared_folder_id: "sf1".into(),
            },
        );

        let padded = format!(" {} \n", encode(&original));
        let restored = decode(&padded).unwrap();

        assert_eq!(restored.library_id, "lib-ws");
    }

    // A join request keeps its email through encode + decode.
    #[test]
    fn join_request_round_trip_with_email() {
        let original = join_request("abcdef1234567890", Some("[email protected]"));

        let restored = decode_join_request(&encode_join_request(&original)).unwrap();

        assert_eq!(restored.public_key, "abcdef1234567890");
        assert_eq!(restored.email, Some("[email protected]".to_string()));
    }

    // A join request without an email decodes back to `None`.
    #[test]
    fn join_request_round_trip_without_email() {
        let original = join_request("deadbeef", None);

        let restored = decode_join_request(&encode_join_request(&original)).unwrap();

        assert_eq!(restored.public_key, "deadbeef");
        assert_eq!(restored.email, None);
    }

    // Whitespace surrounding a pasted join request code is ignored.
    #[test]
    fn join_request_trims_whitespace() {
        let original = join_request("aabbccdd", None);

        let padded = format!(" {} \n", encode_join_request(&original));
        let restored = decode_join_request(&padded).unwrap();

        assert_eq!(restored.public_key, "aabbccdd");
    }
}