sourceware.org Git - systemtap.git/blob - stap-client-connect.c
Merge branch 'master' of ssh://sources.redhat.com/git/systemtap
[systemtap.git] / stap-client-connect.c
1 /*
2 SSL client program that sets up a connection to a SSL server, transmits
3 the given input file and then writes the reply to the given output file.
4
5 Copyright (C) 2008, 2009 Red Hat Inc.
6
7 This file is part of systemtap, and is free software. You can
8 redistribute it and/or modify it under the terms of the GNU General Public
9 License as published by the Free Software Foundation; either version 2 of the
10 License, or (at your option) any later version.
11
12 This program is distributed in the hope that it will be useful,
13 but WITHOUT ANY WARRANTY; without even the implied warranty of
14 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15 GNU General Public License for more details.
16
17 You should have received a copy of the GNU General Public License
18 along with this program; if not, write to the Free Software
19 Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
20 */
21
#include <stdio.h>
#include <stdlib.h>

#include <ssl.h>
#include <nspr.h>
#include <plgetopt.h>
#include <nss.h>
#include <prerror.h>
#include <secerr.h>
#include <sslerr.h>
31
/* Size of the buffer used to receive the server's reply. */
#define READ_BUFFER_SIZE (60 * 1024)

/* Command-line settings filled in by main().  Objects with static storage
   duration start out as NULL / 0, which main() treats as "not given".  */
static char *hostName;          /* -h: server host name */
static unsigned short port;     /* -p: server port */
static const char *infileName;  /* -i: file sent to the server */
static const char *outfileName; /* -o: file that receives the reply */
37
/* Print the command-line synopsis for PROGNAME to stderr and terminate the
   process with exit status 1.  Does not return.  */
static void
Usage(const char *progName)
{
  (void) fprintf(stderr,
                 "Usage: %s -h hostname -p port -d dbdir -i infile -o outfile\n",
                 progName);
  exit(1);
  /* NOTREACHED */
}
45
46 static void
47 errWarn(char *function)
48 {
49 PRErrorCode errorNumber;
50 PRInt32 errorTextLength;
51 PRInt32 rc;
52 char *errorText;
53
54 errorNumber = PR_GetError();
55 fprintf(stderr, "Error in function %s: %d: ", function, errorNumber);
56
57 /* See if PR_GetErrorText can tell us what the error is. */
58 if (errorNumber >= PR_NSPR_ERROR_BASE && errorNumber <= PR_MAX_ERROR)
59 {
60 errorTextLength = PR_GetErrorTextLength ();
61 if (errorTextLength != 0) {
62 errorText = PORT_Alloc(errorTextLength);
63 rc = PR_GetErrorText (errorText);
64 if (rc != 0)
65 fprintf (stderr, "%s\n", errorText);
66 PR_Free (errorText);
67 if (rc != 0)
68 return;
69 }
70 }
71
72 /* Otherwise handle common errors ourselves. */
73 switch (errorNumber)
74 {
75 case SEC_ERROR_CA_CERT_INVALID:
76 fputs ("The issuer's certificate is invalid\n", stderr);
77 break;
78 case SEC_ERROR_BAD_DATABASE:
79 fputs ("The specified certificate database does not exist or is not valid\n", stderr);
80 break;
81 case SSL_ERROR_BAD_CERT_DOMAIN:
82 fputs ("The requested domain name does not match the server's certificate\n", stderr);
83 break;
84 case PR_CONNECT_RESET_ERROR:
85 fputs ("Connection reset by peer\n", stderr);
86 break;
87 default:
88 fputs ("Unknown error\n", stderr);
89 break;
90 }
91 }
92
/* Report FUNCTION's failure via errWarn, shut down NSS and NSPR, and
   terminate the process with exit status 1.  Does not return.  */
static void
exitErr(char *function)
{
  errWarn(function);

  /* The process exits with a failure status either way, so the result of
     NSS_Shutdown is deliberately ignored.  */
  (void) NSS_Shutdown();
  PR_Cleanup();

  exit(1);
}
103
104 static PRFileDesc *
105 setupSSLSocket(void)
106 {
107 PRFileDesc *tcpSocket;
108 PRFileDesc *sslSocket;
109 PRSocketOptionData socketOption;
110 PRStatus prStatus;
111 SECStatus secStatus;
112
113 tcpSocket = PR_NewTCPSocket();
114 if (tcpSocket == NULL)
115 {
116 errWarn("PR_NewTCPSocket");
117 }
118
119 /* Make the socket blocking. */
120 socketOption.option = PR_SockOpt_Nonblocking;
121 socketOption.value.non_blocking = PR_FALSE;
122
123 prStatus = PR_SetSocketOption(tcpSocket, &socketOption);
124 if (prStatus != PR_SUCCESS)
125 {
126 errWarn("PR_SetSocketOption");
127 goto loser;
128 }
129
130 /* Import the socket into the SSL layer. */
131 sslSocket = SSL_ImportFD(NULL, tcpSocket);
132 if (!sslSocket)
133 {
134 errWarn("SSL_ImportFD");
135 goto loser;
136 }
137
138 /* Set configuration options. */
139 secStatus = SSL_OptionSet(sslSocket, SSL_SECURITY, PR_TRUE);
140 if (secStatus != SECSuccess)
141 {
142 errWarn("SSL_OptionSet:SSL_SECURITY");
143 goto loser;
144 }
145
146 secStatus = SSL_OptionSet(sslSocket, SSL_HANDSHAKE_AS_CLIENT, PR_TRUE);
147 if (secStatus != SECSuccess)
148 {
149 errWarn("SSL_OptionSet:SSL_HANDSHAKE_AS_CLIENT");
150 goto loser;
151 }
152
153 /* Set SSL callback routines. */
154 #if 0 /* no client authentication */
155 secStatus = SSL_GetClientAuthDataHook(sslSocket,
156 (SSLGetClientAuthData)myGetClientAuthData,
157 (void *)certNickname);
158 if (secStatus != SECSuccess)
159 {
160 errWarn("SSL_GetClientAuthDataHook");
161 goto loser;
162 }
163 #endif
164 #if 0 /* Use the default */
165 secStatus = SSL_AuthCertificateHook(sslSocket,
166 (SSLAuthCertificate)myAuthCertificate,
167 (void *)CERT_GetDefaultCertDB());
168 if (secStatus != SECSuccess)
169 {
170 errWarn("SSL_AuthCertificateHook");
171 goto loser;
172 }
173 #endif
174 #if 0 /* Use the default */
175 secStatus = SSL_BadCertHook(sslSocket,
176 (SSLBadCertHandler)myBadCertHandler, NULL);
177 if (secStatus != SECSuccess)
178 {
179 errWarn("SSL_BadCertHook");
180 goto loser;
181 }
182 #endif
183 #if 0 /* No handshake callback */
184 secStatus = SSL_HandshakeCallback(sslSocket, myHandshakeCallback, NULL);
185 if (secStatus != SECSuccess)
186 {
187 errWarn("SSL_HandshakeCallback");
188 goto loser;
189 }
190 #endif
191
192 return sslSocket;
193
194 loser:
195 PR_Close(tcpSocket);
196 return NULL;
197 }
198
199
200 static SECStatus
201 handle_connection(PRFileDesc *sslSocket)
202 {
203 #if DEBUG
204 int countRead = 0;
205 #endif
206 PRInt32 numBytes;
207 char *readBuffer;
208 PRFileInfo info;
209 PRFileDesc *local_file_fd;
210 PRStatus prStatus;
211
212 /* read and send the data. */
213 /* Try to open the local file named.
214 * If successful, then write it to the server
215 */
216 prStatus = PR_GetFileInfo(infileName, &info);
217 if (prStatus != PR_SUCCESS ||
218 info.type != PR_FILE_FILE ||
219 info.size < 0)
220 {
221 fprintf (stderr, "could not find input file %s\n", infileName);
222 return SECFailure;
223 }
224
225 local_file_fd = PR_Open(infileName, PR_RDONLY, 0);
226 if (local_file_fd == NULL)
227 {
228 fprintf (stderr, "could not open input file %s\n", infileName);
229 return SECFailure;
230 }
231
232 /* Send the file size first, so the server knows when it has the entire file. */
233 numBytes = PR_Write(sslSocket, & info.size, sizeof (info.size));
234 if (numBytes < 0)
235 {
236 errWarn("PR_Write");
237 return SECFailure;
238 }
239
240 /* Transmit the local file across the socket. */
241 numBytes = PR_TransmitFile(sslSocket, local_file_fd,
242 NULL, 0,
243 PR_TRANSMITFILE_KEEP_OPEN,
244 PR_INTERVAL_NO_TIMEOUT);
245 if (numBytes < 0)
246 {
247 errWarn("PR_TransmitFile");
248 return SECFailure;
249 }
250
251 #if DEBUG
252 /* Transmitted bytes successfully. */
253 fprintf(stderr, "PR_TransmitFile wrote %d bytes from %s\n",
254 numBytes, infileName);
255 #endif
256
257 PR_Close(local_file_fd);
258
259 /* read until EOF */
260 readBuffer = PORT_Alloc(READ_BUFFER_SIZE);
261 if (! readBuffer)
262 exitErr("PORT_Alloc");
263
264 local_file_fd = PR_Open(outfileName, PR_WRONLY | PR_CREATE_FILE | PR_TRUNCATE,
265 PR_IRUSR | PR_IWUSR | PR_IRGRP | PR_IWGRP | PR_IROTH);
266 if (local_file_fd == NULL)
267 {
268 fprintf (stderr, "could not open output file %s\n", outfileName);
269 return SECFailure;
270 }
271 while (PR_TRUE)
272 {
273 numBytes = PR_Read(sslSocket, readBuffer, READ_BUFFER_SIZE);
274 if (numBytes == 0)
275 break; /* EOF */
276
277 if (numBytes < 0)
278 {
279 errWarn("PR_Read");
280 break;
281 }
282 #if DEBUG
283 countRead += numBytes;
284 #endif
285 /* Write to output file */
286 numBytes = PR_Write(local_file_fd, readBuffer, numBytes);
287 if (numBytes < 0)
288 {
289 fprintf (stderr, "could not write to %s\n", outfileName);
290 break;
291 }
292 #if DEBUG
293 fprintf(stderr, "***** Connection read %d bytes (%d total).\n",
294 numBytes, countRead );
295 readBuffer[numBytes] = '\0';
296 fprintf(stderr, "************\n%s\n************\n", readBuffer);
297 #endif
298 }
299
300 PR_Free(readBuffer);
301 PR_Close(local_file_fd);
302
303 /* Caller closes the socket. */
304 #if DEBUG
305 fprintf(stderr, "***** Connection read %d bytes total.\n", countRead);
306 #endif
307
308 return SECSuccess;
309 }
310
311 /* make the connection.
312 */
313 static SECStatus
314 do_connect(PRNetAddr *addr)
315 {
316 PRFileDesc *sslSocket;
317 PRStatus prStatus;
318 #if 0
319 PRHostEnt hostEntry;
320 char buffer[PR_NETDB_BUF_SIZE];
321 PRIntn hostenum;
322 #endif
323 SECStatus secStatus;
324
325 secStatus = SECSuccess;
326
327 /* Set up SSL secure socket. */
328 sslSocket = setupSSLSocket();
329 if (sslSocket == NULL)
330 {
331 errWarn("setupSSLSocket");
332 return SECFailure;
333 }
334
335 #if 0 /* no client authentication */
336 secStatus = SSL_SetPKCS11PinArg(sslSocket, password);
337 if (secStatus != SECSuccess)
338 {
339 errWarn("SSL_SetPKCS11PinArg");
340 goto done;
341 }
342 #endif
343
344 secStatus = SSL_SetURL(sslSocket, hostName);
345 if (secStatus != SECSuccess)
346 {
347 errWarn("SSL_SetURL");
348 goto done;
349 }
350 #if 0 /* Already done? */
351 /* Prepare and setup network connection. */
352 prStatus = PR_GetHostByName(hostName, buffer, sizeof(buffer), &hostEntry);
353 if (prStatus != PR_SUCCESS)
354 {
355 errWarn("PR_GetHostByName");
356 secStatus = SECFailure;
357 goto done;
358 }
359
360 hostenum = PR_EnumerateHostEnt(0, &hostEntry, port, addr);
361 if (hostenum == -1)
362 {
363 errWarn("PR_EnumerateHostEnt");
364 secStatus = SECFailure;
365 goto done;
366 }
367 #endif
368 prStatus = PR_Connect(sslSocket, addr, PR_INTERVAL_NO_TIMEOUT);
369 if (prStatus != PR_SUCCESS)
370 {
371 errWarn("PR_Connect");
372 secStatus = SECFailure;
373 goto done;
374 }
375
376 /* Established SSL connection, ready to send data. */
377 secStatus = SSL_ResetHandshake(sslSocket, /* asServer */ PR_FALSE);
378 if (secStatus != SECSuccess)
379 {
380 errWarn("SSL_ResetHandshake");
381 goto done;
382 }
383
384 /* This is normally done automatically on the first I/O operation,
385 but doing it here catches any authentication problems early. */
386 secStatus = SSL_ForceHandshake(sslSocket);
387 if (secStatus != SECSuccess)
388 {
389 errWarn("SSL_ForceHandshake");
390 goto done;
391 }
392
393 secStatus = handle_connection(sslSocket);
394 if (secStatus != SECSuccess)
395 {
396 errWarn("handle_connection");
397 goto done;
398 }
399
400 done:
401 prStatus = PR_Close(sslSocket);
402 if (prStatus != PR_SUCCESS)
403 errWarn("PR_Close");
404
405 return secStatus;
406 }
407
408 static void
409 client_main(unsigned short port)
410 {
411 SECStatus secStatus;
412 PRStatus prStatus;
413 PRInt32 rv;
414 PRNetAddr addr;
415 PRHostEnt hostEntry;
416 char buffer[PR_NETDB_BUF_SIZE];
417
418 /* Setup network connection. */
419 prStatus = PR_GetHostByName(hostName, buffer, sizeof (buffer), &hostEntry);
420 if (prStatus != PR_SUCCESS)
421 exitErr("PR_GetHostByName");
422
423 rv = PR_EnumerateHostEnt(0, &hostEntry, port, &addr);
424 if (rv < 0)
425 exitErr("PR_EnumerateHostEnt");
426
427 secStatus = do_connect (&addr);
428 if (secStatus != SECSuccess)
429 exitErr("do_connect");
430 }
431
#if 0 /* No client authorization */
/* PK11 password callback (registered by the disabled code in main).
   On the first attempt, return a heap copy of the password passed as ARG.
   On a retry, or when no password was supplied, return NULL.  */
static char *
myPasswd(PK11SlotInfo *info, PRBool retry, void *arg)
{
  if (retry || arg == NULL)
    return NULL;

  return PORT_Strdup((char *)arg);
}
#endif
444
445 int
446 main(int argc, char **argv)
447 {
448 char * certDir = NULL;
449 char * progName = NULL;
450 SECStatus secStatus;
451 PLOptState *optstate;
452 PLOptStatus status;
453
454 /* Call the NSPR initialization routines */
455 PR_Init( PR_SYSTEM_THREAD, PR_PRIORITY_NORMAL, 1);
456
457 progName = PL_strdup(argv[0]);
458
459 hostName = NULL;
460 optstate = PL_CreateOptState(argc, argv, "d:h:i:o:p:");
461 while ((status = PL_GetNextOpt(optstate)) == PL_OPT_OK)
462 {
463 switch(optstate->option)
464 {
465 case 'd' : certDir = PL_strdup(optstate->value); break;
466 case 'h' : hostName = PL_strdup(optstate->value); break;
467 case 'i' : infileName = PL_strdup(optstate->value); break;
468 case 'o' : outfileName = PL_strdup(optstate->value); break;
469 case 'p' : port = PORT_Atoi(optstate->value); break;
470 case '?' :
471 default : Usage(progName);
472 }
473 }
474
475 if (port == 0 || hostName == NULL || infileName == NULL || outfileName == NULL || certDir == NULL)
476 Usage(progName);
477
478 #if 0 /* no client authentication */
479 /* Set our password function callback. */
480 PK11_SetPasswordFunc(myPasswd);
481 #endif
482
483 /* Initialize the NSS libraries. */
484 secStatus = NSS_Init(certDir);
485 if (secStatus != SECSuccess)
486 exitErr("NSS_Init");
487
488 /* All cipher suites except RSA_NULL_MD5 are enabled by Domestic Policy. */
489 NSS_SetDomesticPolicy();
490
491 client_main(port);
492
493 NSS_Shutdown();
494 PR_Cleanup();
495
496 return 0;
497 }
498
499 /* vim: set sw=2 ts=8 cino=>4,n-2,{2,^-2,t0,(0,u0,w1,M1 : */
This page took 0.060251 seconds and 6 git commands to generate.