View Javadoc
1   /*
2    * ====================================================================
3    * Licensed to the Apache Software Foundation (ASF) under one
4    * or more contributor license agreements.  See the NOTICE file
5    * distributed with this work for additional information
6    * regarding copyright ownership.  The ASF licenses this file
7    * to you under the Apache License, Version 2.0 (the
8    * "License"); you may not use this file except in compliance
9    * with the License.  You may obtain a copy of the License at
10   *
11   *   http://www.apache.org/licenses/LICENSE-2.0
12   *
13   * Unless required by applicable law or agreed to in writing,
14   * software distributed under the License is distributed on an
15   * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16   * KIND, either express or implied.  See the License for the
17   * specific language governing permissions and limitations
18   * under the License.
19   * ====================================================================
20   *
21   * This software consists of voluntary contributions made by many
22   * individuals on behalf of the Apache Software Foundation.  For more
23   * information on the Apache Software Foundation, please see
24   * <http://www.apache.org/>.
25   *
26   */
27  
28  package org.apache.http.impl.auth;
29  
30  
31  import java.nio.ByteBuffer;
32  import java.nio.charset.Charset;
33  import java.security.KeyManagementException;
34  import java.security.NoSuchAlgorithmException;
35  import java.security.PublicKey;
36  import java.security.cert.Certificate;
37  import java.security.cert.CertificateException;
38  import java.security.cert.X509Certificate;
39  import java.util.Arrays;
40  
41  import javax.net.ssl.SSLContext;
42  import javax.net.ssl.SSLEngine;
43  import javax.net.ssl.SSLEngineResult;
44  import javax.net.ssl.SSLEngineResult.HandshakeStatus;
45  import javax.net.ssl.SSLEngineResult.Status;
46  import javax.net.ssl.SSLException;
47  import javax.net.ssl.SSLPeerUnverifiedException;
48  import javax.net.ssl.SSLSession;
49  import javax.net.ssl.TrustManager;
50  import javax.net.ssl.X509TrustManager;
51  
52  import org.apache.commons.codec.binary.Base64;
53  import org.apache.commons.logging.Log;
54  import org.apache.commons.logging.LogFactory;
55  import org.apache.http.Consts;
56  import org.apache.http.Header;
57  import org.apache.http.HttpRequest;
58  import org.apache.http.auth.AUTH;
59  import org.apache.http.auth.AuthenticationException;
60  import org.apache.http.auth.Credentials;
61  import org.apache.http.auth.InvalidCredentialsException;
62  import org.apache.http.auth.MalformedChallengeException;
63  import org.apache.http.auth.NTCredentials;
64  import org.apache.http.message.BufferedHeader;
65  import org.apache.http.protocol.HttpContext;
66  import org.apache.http.ssl.SSLContexts;
67  import org.apache.http.util.CharArrayBuffer;
68  import org.apache.http.util.CharsetUtils;
69  
70  /**
71   * <p>
72   * Client implementation of the CredSSP protocol specified in [MS-CSSP].
73   * </p>
74   * <p>
 * Note: This implementation is NOT GSS-based, although ideally it should be. There is no Java NTLM
 * implementation available as a GSS module; if the NTLMEngine is ever converted to GSS, this class
 * could be switched to GSS as well. As it stands, it only works for CredSSP with NTLM.
78   * </p>
79   * <p>
80   * Based on [MS-CSSP]: Credential Security Support Provider (CredSSP) Protocol (Revision 13.0, 7/14/2016).
81   * The implementation was inspired by Python CredSSP and NTLM implementation by Jordan Borean.
82   * </p>
83   */
84  public class CredSspScheme extends AuthSchemeBase
85  {
86      private static final Charset UNICODE_LITTLE_UNMARKED = CharsetUtils.lookup( "UnicodeLittleUnmarked" );
87      public static final String SCHEME_NAME = "CredSSP";
88  
89      private final Log log = LogFactory.getLog( CredSspScheme.class );
90  
91      enum State
92      {
93          // Nothing sent, nothing received
94          UNINITIATED,
95  
96          // We are handshaking. Several messages are exchanged in this state
97          TLS_HANDSHAKE,
98  
99          // TLS handshake finished. Channel established
100         TLS_HANDSHAKE_FINISHED,
101 
102         // NTLM NEGOTIATE message sent (strictly speaking this should be SPNEGO)
103         NEGO_TOKEN_SENT,
104 
105         // NTLM CHALLENGE message received  (strictly speaking this should be SPNEGO)
106         NEGO_TOKEN_RECEIVED,
107 
108         // NTLM AUTHENTICATE message sent together with a server public key
109         PUB_KEY_AUTH_SENT,
110 
111         // Server public key authentication message received
112         PUB_KEY_AUTH_RECEIVED,
113 
114         // Credentials message sent. Protocol exchange finished.
115         CREDENTIALS_SENT;
116     }
117 
118     private State state;
119     private SSLEngine sslEngine;
120     private NTLMEngineImpl.Type1Message type1Message;
121     private NTLMEngineImpl.Type2Message type2Message;
122     private NTLMEngineImpl.Type3Message type3Message;
123     private CredSspTsRequest lastReceivedTsRequest;
124     private NTLMEngineImpl.Handle ntlmOutgoingHandle;
125     private NTLMEngineImpl.Handle ntlmIncomingHandle;
126     private byte[] peerPublicKey;
127 
128 
129     public CredSspScheme() {
130         state = State.UNINITIATED;
131     }
132 
133 
134     @Override
135     public String getSchemeName()
136     {
137         return SCHEME_NAME;
138     }
139 
140 
141     @Override
142     public String getParameter( final String name )
143     {
144         return null;
145     }
146 
147 
148     @Override
149     public String getRealm()
150     {
151         return null;
152     }
153 
154 
155     @Override
156     public boolean isConnectionBased()
157     {
158         return true;
159     }
160 
161 
162     private SSLEngine getSSLEngine()
163     {
164         if ( sslEngine == null )
165         {
166             sslEngine = createSSLEngine();
167         }
168         return sslEngine;
169     }
170 
171 
172     private SSLEngine createSSLEngine()
173     {
174         final SSLContext sslContext;
175         try
176         {
177             sslContext = SSLContexts.custom().build();
178         }
179         catch ( final NoSuchAlgorithmException e )
180         {
181             throw new RuntimeException( "Error creating SSL Context: " + e.getMessage(), e );
182         }
183         catch ( final KeyManagementException e )
184         {
185             throw new RuntimeException( "Error creating SSL Context: " + e.getMessage(), e );
186         }
187 
188         final X509TrustManager tm = new X509TrustManager()
189         {
190 
191             @Override
192             public void checkClientTrusted( final X509Certificate[] chain, final String authType )
193                 throws CertificateException
194             {
195                 // Nothing to do.
196             }
197 
198 
199             @Override
200             public void checkServerTrusted( final X509Certificate[] chain, final String authType )
201                 throws CertificateException
202             {
203                 // Nothing to do, accept all. CredSSP server is using its own certificate without any
204                 // binding to the PKI trust chains. The public key is verified as part of the CredSSP
205                 // protocol exchange.
206             }
207 
208 
209             @Override
210             public X509Certificate[] getAcceptedIssuers()
211             {
212                 return null;
213             }
214 
215         };
216         try
217         {
218             sslContext.init( null, new TrustManager[]
219                 { tm }, null );
220         }
221         catch ( final KeyManagementException e )
222         {
223             throw new RuntimeException( "SSL Context initialization error: " + e.getMessage(), e );
224         }
225         final SSLEngine sslEngine = sslContext.createSSLEngine();
226         sslEngine.setUseClientMode( true );
227         return sslEngine;
228     }
229 
230 
231     @Override
232     protected void parseChallenge( final CharArrayBuffer buffer, final int beginIndex, final int endIndex )
233         throws MalformedChallengeException
234     {
235         final String inputString = buffer.substringTrimmed( beginIndex, endIndex );
236 
237         if ( inputString.isEmpty() )
238         {
239             if ( state == State.UNINITIATED )
240             {
241                 // This is OK, just send out first message. That should start TLS handshake
242             }
243             else
244             {
245                 final String msg = "Received unexpected empty input in state " + state;
246                 log.error( msg );
247                 throw new MalformedChallengeException( msg );
248             }
249         }
250 
251         if ( state == State.TLS_HANDSHAKE )
252         {
253             unwrapHandshake( inputString );
254             if ( getSSLEngine().getHandshakeStatus() == HandshakeStatus.NOT_HANDSHAKING )
255             {
256                 log.trace( "TLS handshake finished" );
257                 state = State.TLS_HANDSHAKE_FINISHED;
258             }
259         }
260 
261         if ( state == State.NEGO_TOKEN_SENT )
262         {
263             final ByteBuffer buf = unwrap( inputString );
264             state = State.NEGO_TOKEN_RECEIVED;
265             lastReceivedTsRequest = CredSspTsRequest.createDecoded( buf );
266         }
267 
268         if ( state == State.PUB_KEY_AUTH_SENT )
269         {
270             final ByteBuffer buf = unwrap( inputString );
271             state = State.PUB_KEY_AUTH_RECEIVED;
272             lastReceivedTsRequest = CredSspTsRequest.createDecoded( buf );
273         }
274     }
275 
276 
277     @Override
278     @Deprecated
279     public Header authenticate(
280         final Credentials credentials,
281         final HttpRequest request ) throws AuthenticationException
282     {
283         return authenticate( credentials, request, null );
284     }
285 
286 
287     @Override
288     public Header authenticate(
289         final Credentials credentials,
290         final HttpRequest request,
291         final HttpContext context ) throws AuthenticationException
292     {
293         NTCredentials ntcredentials = null;
294         try
295         {
296             ntcredentials = ( NTCredentials ) credentials;
297         }
298         catch ( final ClassCastException e )
299         {
300             throw new InvalidCredentialsException(
301                 "Credentials cannot be used for CredSSP authentication: "
302                     + credentials.getClass().getName() );
303         }
304 
305         String outputString = null;
306 
307         if ( state == State.UNINITIATED )
308         {
309             beginTlsHandshake();
310             outputString = wrapHandshake();
311             state = State.TLS_HANDSHAKE;
312 
313         }
314         else if ( state == State.TLS_HANDSHAKE )
315         {
316             outputString = wrapHandshake();
317 
318         }
319         else if ( state == State.TLS_HANDSHAKE_FINISHED )
320         {
321 
322             final int ntlmFlags = getNtlmFlags();
323             final ByteBuffer buf = allocateOutBuffer();
324             type1Message = new NTLMEngineImpl.Type1Message(
325                 ntcredentials.getDomain(), ntcredentials.getWorkstation(), ntlmFlags);
326             final byte[] ntlmNegoMessageEncoded = type1Message.getBytes();
327             final CredSspTsRequest req = CredSspTsRequest.createNegoToken( ntlmNegoMessageEncoded );
328             req.encode( buf );
329             buf.flip();
330             outputString = wrap( buf );
331             state = State.NEGO_TOKEN_SENT;
332 
333         }
334         else if ( state == State.NEGO_TOKEN_RECEIVED )
335         {
336             final ByteBuffer buf = allocateOutBuffer();
337             type2Message = new NTLMEngineImpl.Type2Message(
338                 lastReceivedTsRequest.getNegoToken());
339 
340             final Certificate peerServerCertificate = getPeerServerCertificate();
341 
342             type3Message = new NTLMEngineImpl.Type3Message(
343                 ntcredentials.getDomain(),
344                 ntcredentials.getWorkstation(),
345                 ntcredentials.getUserName(),
346                 ntcredentials.getPassword(),
347                 type2Message.getChallenge(),
348                 type2Message.getFlags(),
349                 type2Message.getTarget(),
350                 type2Message.getTargetInfo(),
351                 peerServerCertificate,
352                 type1Message.getBytes(),
353                 type2Message.getBytes());
354 
355             final byte[] ntlmAuthenticateMessageEncoded = type3Message.getBytes();
356 
357             final byte[] exportedSessionKey = type3Message.getExportedSessionKey();
358 
359             ntlmOutgoingHandle = new NTLMEngineImpl.Handle(exportedSessionKey, NTLMEngineImpl.Mode.CLIENT, true);
360             ntlmIncomingHandle = new NTLMEngineImpl.Handle(exportedSessionKey, NTLMEngineImpl.Mode.SERVER, true);
361 
362             final CredSspTsRequest req = CredSspTsRequest.createNegoToken( ntlmAuthenticateMessageEncoded );
363             peerPublicKey = getSubjectPublicKeyDer( peerServerCertificate.getPublicKey() );
364             final byte[] pubKeyAuth = createPubKeyAuth();
365             req.setPubKeyAuth( pubKeyAuth );
366 
367             req.encode( buf );
368             buf.flip();
369             outputString = wrap( buf );
370             state = State.PUB_KEY_AUTH_SENT;
371 
372         }
373         else if ( state == State.PUB_KEY_AUTH_RECEIVED )
374         {
375             verifyPubKeyAuthResponse( lastReceivedTsRequest.getPubKeyAuth() );
376             final byte[] authInfo = createAuthInfo( ntcredentials );
377             final CredSspTsRequest req = CredSspTsRequest.createAuthInfo( authInfo );
378 
379             final ByteBuffer buf = allocateOutBuffer();
380             req.encode( buf );
381             buf.flip();
382             outputString = wrap( buf );
383             state = State.CREDENTIALS_SENT;
384         }
385         else
386         {
387             throw new AuthenticationException( "Wrong state " + state );
388         }
389         final CharArrayBuffer buffer = new CharArrayBuffer( 32 );
390         if ( isProxy() )
391         {
392             buffer.append( AUTH.PROXY_AUTH_RESP );
393         }
394         else
395         {
396             buffer.append( AUTH.WWW_AUTH_RESP );
397         }
398         buffer.append( ": CredSSP " );
399         buffer.append( outputString );
400         return new BufferedHeader( buffer );
401     }
402 
403 
404     private int getNtlmFlags()
405     {
406         return NTLMEngineImpl.FLAG_REQUEST_OEM_ENCODING |
407             NTLMEngineImpl.FLAG_REQUEST_SIGN |
408             NTLMEngineImpl.FLAG_REQUEST_SEAL |
409             NTLMEngineImpl.FLAG_DOMAIN_PRESENT |
410             NTLMEngineImpl.FLAG_REQUEST_ALWAYS_SIGN |
411             NTLMEngineImpl.FLAG_REQUEST_NTLM2_SESSION |
412             NTLMEngineImpl.FLAG_TARGETINFO_PRESENT |
413             NTLMEngineImpl.FLAG_REQUEST_VERSION |
414             NTLMEngineImpl.FLAG_REQUEST_128BIT_KEY_EXCH |
415             NTLMEngineImpl.FLAG_REQUEST_EXPLICIT_KEY_EXCH |
416             NTLMEngineImpl.FLAG_REQUEST_56BIT_ENCRYPTION;
417     }
418 
419 
420     private Certificate getPeerServerCertificate() throws AuthenticationException
421     {
422         final Certificate[] peerCertificates;
423         try
424         {
425             peerCertificates = sslEngine.getSession().getPeerCertificates();
426         }
427         catch ( final SSLPeerUnverifiedException e )
428         {
429             throw new AuthenticationException( e.getMessage(), e );
430         }
431         for ( final Certificate peerCertificate : peerCertificates )
432         {
433             if ( !( peerCertificate instanceof X509Certificate ) )
434             {
435                 continue;
436             }
437             final X509Certificate peerX509Cerificate = ( X509Certificate ) peerCertificate;
438             if ( peerX509Cerificate.getBasicConstraints() != -1 )
439             {
440                 continue;
441             }
442             return peerX509Cerificate;
443         }
444         return null;
445     }
446 
447 
448     private byte[] createPubKeyAuth() throws AuthenticationException
449     {
450         return ntlmOutgoingHandle.signAndEncryptMessage( peerPublicKey );
451     }
452 
453 
454     private void verifyPubKeyAuthResponse( final byte[] pubKeyAuthResponse ) throws AuthenticationException
455     {
456         final byte[] pubKeyReceived = ntlmIncomingHandle.decryptAndVerifySignedMessage( pubKeyAuthResponse );
457 
458         // assert: pubKeyReceived = peerPublicKey + 1
459         // The following algorithm is a bit simplified. But due to the ASN.1 encoding the first byte
460         // of the public key will be 0x30 we can pretty much rely on a fact that there will be no carry
461         if ( peerPublicKey.length != pubKeyReceived.length )
462         {
463             throw new AuthenticationException( "Public key mismatch in pubKeyAuth response" );
464         }
465         if ( ( peerPublicKey[0] + 1 ) != pubKeyReceived[0] )
466         {
467             throw new AuthenticationException( "Public key mismatch in pubKeyAuth response" );
468         }
469         for ( int i = 1; i < peerPublicKey.length; i++ )
470         {
471             if ( peerPublicKey[i] != pubKeyReceived[i] )
472             {
473                 throw new AuthenticationException( "Public key mismatch in pubKeyAuth response" );
474             }
475         }
476         log.trace( "Received public key response is valid" );
477     }
478 
479 
480     private byte[] createAuthInfo( final NTCredentials ntcredentials ) throws AuthenticationException
481     {
482 
483         final byte[] domainBytes = encodeUnicode( ntcredentials.getDomain() );
484         final byte[] domainOctetStringBytesLengthBytes = encodeLength( domainBytes.length );
485         final int domainNameLength = 1 + domainOctetStringBytesLengthBytes.length + domainBytes.length;
486         final byte[] domainNameLengthBytes = encodeLength( domainNameLength );
487 
488         final byte[] usernameBytes = encodeUnicode( ntcredentials.getUserName() );
489         final byte[] usernameOctetStringBytesLengthBytes = encodeLength( usernameBytes.length );
490         final int userNameLength = 1 + usernameOctetStringBytesLengthBytes.length + usernameBytes.length;
491         final byte[] userNameLengthBytes = encodeLength( userNameLength );
492 
493         final byte[] passwordBytes = encodeUnicode( ntcredentials.getPassword() );
494         final byte[] passwordOctetStringBytesLengthBytes = encodeLength( passwordBytes.length );
495         final int passwordLength = 1 + passwordOctetStringBytesLengthBytes.length + passwordBytes.length;
496         final byte[] passwordLengthBytes = encodeLength( passwordLength );
497 
498         final int tsPasswordLength = 1 + domainNameLengthBytes.length + domainNameLength +
499             1 + userNameLengthBytes.length + userNameLength +
500             1 + passwordLengthBytes.length + passwordLength;
501         final byte[] tsPasswordLengthBytes = encodeLength( tsPasswordLength );
502         final int credentialsOctetStringLength = 1 + tsPasswordLengthBytes.length + tsPasswordLength;
503         final byte[] credentialsOctetStringLengthBytes = encodeLength( credentialsOctetStringLength );
504         final int credentialsLength = 1 + credentialsOctetStringLengthBytes.length + credentialsOctetStringLength;
505         final byte[] credentialsLengthBytes = encodeLength( credentialsLength );
506         final int tsCredentialsLength = 5 + 1 + credentialsLengthBytes.length + credentialsLength;
507         final byte[] tsCredentialsLengthBytes = encodeLength( tsCredentialsLength );
508 
509         final ByteBuffer buf = ByteBuffer.allocate( 1 + tsCredentialsLengthBytes.length + tsCredentialsLength );
510 
511         // TSCredentials structure [MS-CSSP] section 2.2.1.2
512         buf.put( ( byte ) 0x30 ); // seq
513         buf.put( tsCredentialsLengthBytes );
514 
515         buf.put( ( byte ) ( 0x00 | 0xa0 ) ); // credType tag [0]
516         buf.put( ( byte ) 3 ); // credType length
517         buf.put( ( byte ) 0x02 ); // type: INTEGER
518         buf.put( ( byte ) 1 ); // credType inner length
519         buf.put( ( byte ) 1 ); // credType value: 1 (password)
520 
521         buf.put( ( byte ) ( 0x01 | 0xa0 ) ); // credentials tag [1]
522         buf.put( credentialsLengthBytes );
523         buf.put( ( byte ) 0x04 ); // type: OCTET STRING
524         buf.put( credentialsOctetStringLengthBytes );
525 
526         // TSPasswordCreds structure [MS-CSSP] section 2.2.1.2.1
527         buf.put( ( byte ) 0x30 ); // seq
528         buf.put( tsPasswordLengthBytes );
529 
530         buf.put( ( byte ) ( 0x00 | 0xa0 ) ); // domainName tag [0]
531         buf.put( domainNameLengthBytes );
532         buf.put( ( byte ) 0x04 ); // type: OCTET STRING
533         buf.put( domainOctetStringBytesLengthBytes );
534         buf.put( domainBytes );
535 
536         buf.put( ( byte ) ( 0x01 | 0xa0 ) ); // userName tag [1]
537         buf.put( userNameLengthBytes );
538         buf.put( ( byte ) 0x04 ); // type: OCTET STRING
539         buf.put( usernameOctetStringBytesLengthBytes );
540         buf.put( usernameBytes );
541 
542         buf.put( ( byte ) ( 0x02 | 0xa0 ) ); // password tag [2]
543         buf.put( passwordLengthBytes );
544         buf.put( ( byte ) 0x04 ); // type: OCTET STRING
545         buf.put( passwordOctetStringBytesLengthBytes );
546         buf.put( passwordBytes );
547 
548         final byte[] authInfo = buf.array();
549         try
550         {
551             return ntlmOutgoingHandle.signAndEncryptMessage( authInfo );
552         }
553         catch ( final NTLMEngineException e )
554         {
555             throw new AuthenticationException( e.getMessage(), e );
556         }
557     }
558 
559     private final static byte[] EMPTYBUFFER = new byte[0];
560 
561     private byte[] encodeUnicode( final String string )
562     {
563         if (string == null) {
564             return EMPTYBUFFER;
565         }
566         return string.getBytes( UNICODE_LITTLE_UNMARKED );
567     }
568 
569 
570     private byte[] getSubjectPublicKeyDer( final PublicKey publicKey ) throws AuthenticationException
571     {
572         // The publicKey.getEncoded() returns encoded SubjectPublicKeyInfo structure. But the CredSSP expects
573         // SubjectPublicKey subfield. I have found no easy way how to get just the SubjectPublicKey from
574         // java.security libraries. So let's use a primitive way and parse it out from the DER.
575 
576         try
577         {
578             final byte[] encodedPubKeyInfo = publicKey.getEncoded();
579 
580             final ByteBuffer buf = ByteBuffer.wrap( encodedPubKeyInfo );
581             getByteAndAssert( buf, 0x30, "initial sequence" );
582             parseLength( buf );
583             getByteAndAssert( buf, 0x30, "AlgorithmIdentifier sequence" );
584             final int algIdSeqLength = parseLength( buf );
585             buf.position( buf.position() + algIdSeqLength );
586             getByteAndAssert( buf, 0x03, "subjectPublicKey type" );
587             int subjectPublicKeyLegth = parseLength( buf );
588             // There may be leading padding byte ... or whatever that is. Skip that.
589             final byte b = buf.get();
590             if ( b == 0 )
591             {
592                 subjectPublicKeyLegth--;
593             }
594             else
595             {
596                 buf.position( buf.position() - 1 );
597             }
598             final byte[] subjectPublicKey = new byte[subjectPublicKeyLegth];
599             buf.get( subjectPublicKey );
600             return subjectPublicKey;
601         }
602         catch ( final MalformedChallengeException e )
603         {
604             throw new AuthenticationException( e.getMessage(), e );
605         }
606     }
607 
608 
609     private void beginTlsHandshake() throws AuthenticationException
610     {
611         try
612         {
613             getSSLEngine().beginHandshake();
614         }
615         catch ( final SSLException e )
616         {
617             throw new AuthenticationException( "SSL Engine error: " + e.getMessage(), e );
618         }
619     }
620 
621 
622     private ByteBuffer allocateOutBuffer()
623     {
624         final SSLEngine sslEngine = getSSLEngine();
625         final SSLSession sslSession = sslEngine.getSession();
626         return ByteBuffer.allocate( sslSession.getApplicationBufferSize() );
627     }
628 
629 
630     private String wrapHandshake() throws AuthenticationException
631     {
632         final ByteBuffer src = allocateOutBuffer();
633         src.flip();
634         final SSLEngine sslEngine = getSSLEngine();
635         final SSLSession sslSession = sslEngine.getSession();
636         // Needs to be twice the size as there may be two wraps during handshake.
637         // Primitive and inefficient solution, but it works.
638         final ByteBuffer dst = ByteBuffer.allocate( sslSession.getPacketBufferSize() * 2 );
639         while ( sslEngine.getHandshakeStatus() == HandshakeStatus.NEED_WRAP )
640         {
641             wrap( src, dst );
642         }
643         dst.flip();
644         return encodeBase64( dst );
645     }
646 
647 
648     private String wrap( final ByteBuffer src ) throws AuthenticationException
649     {
650         final SSLEngine sslEngine = getSSLEngine();
651         final SSLSession sslSession = sslEngine.getSession();
652         final ByteBuffer dst = ByteBuffer.allocate( sslSession.getPacketBufferSize() );
653         wrap( src, dst );
654         dst.flip();
655         return encodeBase64( dst );
656     }
657 
658 
659     private void wrap( final ByteBuffer src, final ByteBuffer dst ) throws AuthenticationException
660     {
661         final SSLEngine sslEngine = getSSLEngine();
662         try
663         {
664             final SSLEngineResult engineResult = sslEngine.wrap( src, dst );
665             if ( engineResult.getStatus() != Status.OK )
666             {
667                 throw new AuthenticationException( "SSL Engine error status: " + engineResult.getStatus() );
668             }
669         }
670         catch ( final SSLException e )
671         {
672             throw new AuthenticationException( "SSL Engine wrap error: " + e.getMessage(), e );
673         }
674     }
675 
676 
677     private void unwrapHandshake( final String inputString ) throws MalformedChallengeException
678     {
679         final SSLEngine sslEngine = getSSLEngine();
680         final SSLSession sslSession = sslEngine.getSession();
681         final ByteBuffer src = decodeBase64( inputString );
682         final ByteBuffer dst = ByteBuffer.allocate( sslSession.getApplicationBufferSize() );
683         while ( sslEngine.getHandshakeStatus() == HandshakeStatus.NEED_UNWRAP )
684         {
685             unwrap( src, dst );
686         }
687     }
688 
689 
690     private ByteBuffer unwrap( final String inputString ) throws MalformedChallengeException
691     {
692         final SSLEngine sslEngine = getSSLEngine();
693         final SSLSession sslSession = sslEngine.getSession();
694         final ByteBuffer src = decodeBase64( inputString );
695         final ByteBuffer dst = ByteBuffer.allocate( sslSession.getApplicationBufferSize() );
696         unwrap( src, dst );
697         dst.flip();
698         return dst;
699     }
700 
701 
702     private void unwrap( final ByteBuffer src, final ByteBuffer dst ) throws MalformedChallengeException
703     {
704 
705         try
706         {
707             final SSLEngineResult engineResult = sslEngine.unwrap( src, dst );
708             if ( engineResult.getStatus() != Status.OK )
709             {
710                 throw new MalformedChallengeException( "SSL Engine error status: " + engineResult.getStatus() );
711             }
712 
713             if ( sslEngine.getHandshakeStatus() == HandshakeStatus.NEED_TASK )
714             {
715                 final Runnable task = sslEngine.getDelegatedTask();
716                 task.run();
717             }
718 
719         }
720         catch ( final SSLException e )
721         {
722             throw new MalformedChallengeException( "SSL Engine unwrap error: " + e.getMessage(), e );
723         }
724     }
725 
726 
727     private String encodeBase64( final ByteBuffer buffer )
728     {
729         final int limit = buffer.limit();
730         final byte[] bytes = new byte[limit];
731         buffer.get( bytes );
732         return new String(Base64.encodeBase64(bytes), Consts.ASCII);
733     }
734 
735 
736     private ByteBuffer decodeBase64( final String inputString )
737     {
738         final byte[] inputBytes = Base64.decodeBase64(inputString.getBytes(Consts.ASCII));
739         final ByteBuffer buffer = ByteBuffer.wrap( inputBytes );
740         return buffer;
741     }
742 
743 
744     @Override
745     public boolean isComplete()
746     {
747         return state == State.CREDENTIALS_SENT;
748     }
749 
750     /**
751      * Implementation of the TsRequest structure used in CredSSP protocol.
752      * It is specified in [MS-CPPS] section 2.2.1.
753      */
754     static class CredSspTsRequest
755     {
756 
757         private static final int VERSION = 3;
758 
759         private byte[] negoToken;
760         private byte[] authInfo;
761         private byte[] pubKeyAuth;
762 
763 
764         protected CredSspTsRequest()
765         {
766             super();
767         }
768 
769 
770         public static CredSspTsRequest createNegoToken( final byte[] negoToken )
771         {
772             final CredSspTsRequest req = new CredSspTsRequest();
773             req.negoToken = negoToken;
774             return req;
775         }
776 
777 
778         public static CredSspTsRequest createAuthInfo( final byte[] authInfo )
779         {
780             final CredSspTsRequest req = new CredSspTsRequest();
781             req.authInfo = authInfo;
782             return req;
783         }
784 
785 
786         public static CredSspTsRequest createDecoded( final ByteBuffer buf ) throws MalformedChallengeException
787         {
788             final CredSspTsRequest req = new CredSspTsRequest();
789             req.decode( buf );
790             return req;
791         }
792 
793 
794         public byte[] getNegoToken()
795         {
796             return negoToken;
797         }
798 
799 
800         public void setNegoToken( final byte[] negoToken )
801         {
802             this.negoToken = negoToken;
803         }
804 
805 
806         public byte[] getAuthInfo()
807         {
808             return authInfo;
809         }
810 
811 
812         public void setAuthInfo( final byte[] authInfo )
813         {
814             this.authInfo = authInfo;
815         }
816 
817 
818         public byte[] getPubKeyAuth()
819         {
820             return pubKeyAuth;
821         }
822 
823 
824         public void setPubKeyAuth( final byte[] pubKeyAuth )
825         {
826             this.pubKeyAuth = pubKeyAuth;
827         }
828 
829 
830         public void decode( final ByteBuffer buf ) throws MalformedChallengeException
831         {
832             negoToken = null;
833             authInfo = null;
834             pubKeyAuth = null;
835 
836             getByteAndAssert( buf, 0x30, "initial sequence" );
837             parseLength( buf );
838 
839             while ( buf.hasRemaining() )
840             {
841                 final int contentTag = getAndAssertContentSpecificTag( buf, "content tag" );
842                 parseLength( buf );
843                 switch ( contentTag )
844                 {
845                     case 0:
846                         processVersion( buf );
847                         break;
848                     case 1:
849                         parseNegoTokens( buf );
850                         break;
851                     case 2:
852                         parseAuthInfo( buf );
853                         break;
854                     case 3:
855                         parsePubKeyAuth( buf );
856                         break;
857                     case 4:
858                         processErrorCode( buf );
859                         break;
860                     default:
861                         parseError( buf, "unexpected content tag " + contentTag );
862                 }
863             }
864         }
865 
866 
867         private void processVersion( final ByteBuffer buf ) throws MalformedChallengeException
868         {
869             getByteAndAssert( buf, 0x02, "version type" );
870             getLengthAndAssert( buf, 1, "version length" );
871             getByteAndAssert( buf, VERSION, "wrong protocol version" );
872         }
873 
874 
875         private void parseNegoTokens( final ByteBuffer buf ) throws MalformedChallengeException
876         {
877             getByteAndAssert( buf, 0x30, "negoTokens sequence" );
878             parseLength( buf );
879             // I have seen both 0x30LL encoding and 0x30LL0x30LL encoding. Accept both.
880             byte bufByte = buf.get();
881             if ( bufByte == 0x30 )
882             {
883                 parseLength( buf );
884                 bufByte = buf.get();
885             }
886             if ( ( bufByte & 0xff ) != 0xa0 )
887             {
888                 parseError( buf, "negoTokens: wrong content-specific tag " + String.format( "%02X", bufByte ) );
889             }
890             parseLength( buf );
891             getByteAndAssert( buf, 0x04, "negoToken type" );
892 
893             final int tokenLength = parseLength( buf );
894             negoToken = new byte[tokenLength];
895             buf.get( negoToken );
896         }
897 
898 
899         private void parseAuthInfo( final ByteBuffer buf ) throws MalformedChallengeException
900         {
901             getByteAndAssert( buf, 0x04, "authInfo type" );
902             final int length = parseLength( buf );
903             authInfo = new byte[length];
904             buf.get( authInfo );
905         }
906 
907 
908         private void parsePubKeyAuth( final ByteBuffer buf ) throws MalformedChallengeException
909         {
910             getByteAndAssert( buf, 0x04, "pubKeyAuth type" );
911             final int length = parseLength( buf );
912             pubKeyAuth = new byte[length];
913             buf.get( pubKeyAuth );
914         }
915 
916 
917         private void processErrorCode( final ByteBuffer buf ) throws MalformedChallengeException
918         {
919             getLengthAndAssert( buf, 3, "error code length" );
920             getByteAndAssert( buf, 0x02, "error code type" );
921             getLengthAndAssert( buf, 1, "error code length" );
922             final byte errorCode = buf.get();
923             parseError( buf, "Error code " + errorCode );
924         }
925 
926 
927         public void encode( final ByteBuffer buf )
928         {
929             final ByteBuffer inner = ByteBuffer.allocate( buf.capacity() );
930 
931             // version tag [0]
932             inner.put( ( byte ) ( 0x00 | 0xa0 ) );
933             inner.put( ( byte ) 3 ); // length
934 
935             inner.put( ( byte ) ( 0x02 ) ); // INTEGER tag
936             inner.put( ( byte ) 1 ); // length
937             inner.put( ( byte ) VERSION ); // value
938 
939             if ( negoToken != null )
940             {
941                 int len = negoToken.length;
942                 final byte[] negoTokenLengthBytes = encodeLength( len );
943                 len += 1 + negoTokenLengthBytes.length;
944                 final byte[] negoTokenLength1Bytes = encodeLength( len );
945                 len += 1 + negoTokenLength1Bytes.length;
946                 final byte[] negoTokenLength2Bytes = encodeLength( len );
947                 len += 1 + negoTokenLength2Bytes.length;
948                 final byte[] negoTokenLength3Bytes = encodeLength( len );
949                 len += 1 + negoTokenLength3Bytes.length;
950                 final byte[] negoTokenLength4Bytes = encodeLength( len );
951 
952                 inner.put( ( byte ) ( 0x01 | 0xa0 ) ); // negoData tag [1]
953                 inner.put( negoTokenLength4Bytes ); // length
954 
955                 inner.put( ( byte ) ( 0x30 ) ); // SEQUENCE tag
956                 inner.put( negoTokenLength3Bytes ); // length
957 
958                 inner.put( ( byte ) ( 0x30 ) ); // .. of SEQUENCE tag
959                 inner.put( negoTokenLength2Bytes ); // length
960 
961                 inner.put( ( byte ) ( 0x00 | 0xa0 ) ); // negoToken tag [0]
962                 inner.put( negoTokenLength1Bytes ); // length
963 
964                 inner.put( ( byte ) ( 0x04 ) ); // OCTET STRING tag
965                 inner.put( negoTokenLengthBytes ); // length
966 
967                 inner.put( negoToken );
968             }
969 
970             if ( authInfo != null )
971             {
972                 final byte[] authInfoEncodedLength = encodeLength( authInfo.length );
973 
974                 inner.put( ( byte ) ( 0x02 | 0xa0 ) ); // authInfo tag [2]
975                 inner.put( encodeLength( 1 + authInfoEncodedLength.length + authInfo.length ) ); // length
976 
977                 inner.put( ( byte ) ( 0x04 ) ); // OCTET STRING tag
978                 inner.put( authInfoEncodedLength );
979                 inner.put( authInfo );
980             }
981 
982             if ( pubKeyAuth != null )
983             {
984                 final byte[] pubKeyAuthEncodedLength = encodeLength( pubKeyAuth.length );
985 
986                 inner.put( ( byte ) ( 0x03 | 0xa0 ) ); // pubKeyAuth tag [3]
987                 inner.put( encodeLength( 1 + pubKeyAuthEncodedLength.length + pubKeyAuth.length ) ); // length
988 
989                 inner.put( ( byte ) ( 0x04 ) ); // OCTET STRING tag
990                 inner.put( pubKeyAuthEncodedLength );
991                 inner.put( pubKeyAuth );
992             }
993 
994             inner.flip();
995 
996             // SEQUENCE tag
997             buf.put( ( byte ) ( 0x10 | 0x20 ) );
998             buf.put( encodeLength( inner.limit() ) );
999             buf.put( inner );
1000         }
1001 
1002 
1003         public String debugDump()
1004         {
1005             final StringBuilder sb = new StringBuilder( "TsRequest\n" );
1006             sb.append( "  negoToken:\n" );
1007             sb.append( "    " );
1008             DebugUtil.dump( sb, negoToken );
1009             sb.append( "\n" );
1010             sb.append( "  authInfo:\n" );
1011             sb.append( "    " );
1012             DebugUtil.dump( sb, authInfo );
1013             sb.append( "\n" );
1014             sb.append( "  pubKeyAuth:\n" );
1015             sb.append( "    " );
1016             DebugUtil.dump( sb, pubKeyAuth );
1017             return sb.toString();
1018         }
1019 
1020 
1021         @Override
1022         public String toString()
1023         {
1024             return "TsRequest(negoToken=" + Arrays.toString( negoToken ) + ", authInfo="
1025                 + Arrays.toString( authInfo ) + ", pubKeyAuth=" + Arrays.toString( pubKeyAuth ) + ")";
1026         }
1027     }
1028 
1029     static void getByteAndAssert( final ByteBuffer buf, final int expectedValue, final String errorMessage )
1030         throws MalformedChallengeException
1031     {
1032         final byte bufByte = buf.get();
1033         if ( bufByte != expectedValue )
1034         {
1035             parseError( buf, errorMessage + expectMessage( expectedValue, bufByte ) );
1036         }
1037     }
1038 
1039     private static String expectMessage( final int expectedValue, final int realValue )
1040     {
1041         return "(expected " + String.format( "%02X", expectedValue ) + ", got " + String.format( "%02X", realValue )
1042             + ")";
1043     }
1044 
1045     static int parseLength( final ByteBuffer buf )
1046     {
1047         byte bufByte = buf.get();
1048         if ( bufByte == 0x80 )
1049         {
1050             return -1; // infinite
1051         }
1052         if ( ( bufByte & 0x80 ) == 0x80 )
1053         {
1054             final int size = bufByte & 0x7f;
1055             int length = 0;
1056             for ( int i = 0; i < size; i++ )
1057             {
1058                 bufByte = buf.get();
1059                 length = ( length << 8 ) + ( bufByte & 0xff );
1060             }
1061             return length;
1062         }
1063         else
1064         {
1065             return bufByte;
1066         }
1067     }
1068 
1069     static void getLengthAndAssert( final ByteBuffer buf, final int expectedValue, final String errorMessage )
1070         throws MalformedChallengeException
1071     {
1072         final int bufLength = parseLength( buf );
1073         if ( expectedValue != bufLength )
1074         {
1075             parseError( buf, errorMessage + expectMessage( expectedValue, bufLength ) );
1076         }
1077     }
1078 
1079     static int getAndAssertContentSpecificTag( final ByteBuffer buf, final String errorMessage ) throws MalformedChallengeException
1080     {
1081         final byte bufByte = buf.get();
1082         if ( ( bufByte & 0xe0 ) != 0xa0 )
1083         {
1084             parseError( buf, errorMessage + ": wrong content-specific tag " + String.format( "%02X", bufByte ) );
1085         }
1086         final int tag = bufByte & 0x1f;
1087         return tag;
1088     }
1089 
1090     static void parseError( final ByteBuffer buf, final String errorMessage ) throws MalformedChallengeException
1091     {
1092         throw new MalformedChallengeException(
1093             "Error parsing TsRequest (position:" + buf.position() + "): " + errorMessage );
1094     }
1095 
1096     static byte[] encodeLength( final int length )
1097     {
1098         if ( length < 128 )
1099         {
1100             final byte[] encoded = new byte[1];
1101             encoded[0] = ( byte ) length;
1102             return encoded;
1103         }
1104 
1105         int size = 1;
1106 
1107         int val = length;
1108         while ( ( val >>>= 8 ) != 0 )
1109         {
1110             size++;
1111         }
1112 
1113         final byte[] encoded = new byte[1 + size];
1114         encoded[0] = ( byte ) ( size | 0x80 );
1115 
1116         int shift = ( size - 1 ) * 8;
1117         for ( int i = 0; i < size; i++ )
1118         {
1119             encoded[i + 1] = ( byte ) ( length >> shift );
1120             shift -= 8;
1121         }
1122 
1123         return encoded;
1124     }
1125 
1126 }